Add new SetOperation constructor for parenthesized query (#195)

* Add new SetOperation constructor for parenthesized query. Automatically detect when parentheses are needed on SelectQuery usage (only works for MySQL).

* Add Parens to SelectQueryP and create a pattern synonym for SelectQuery. SelectQueryP is hidden as end users should only be using SelectQuery.
This commit is contained in:
Ben Levy 2020-08-30 14:15:11 -05:00 committed by GitHub
parent dd16400d64
commit 2b5b561f6e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 55 additions and 7 deletions

View File

@ -9,6 +9,7 @@
, TypeFamilies , TypeFamilies
, UndecidableInstances , UndecidableInstances
, OverloadedStrings , OverloadedStrings
, PatternSynonyms
#-} #-}
module Database.Esqueleto.Experimental module Database.Esqueleto.Experimental
@ -23,7 +24,8 @@ module Database.Esqueleto.Experimental
-- * Documentation -- * Documentation
SqlSetOperation(..) SqlSetOperation(Union, UnionAll, Except, Intersect)
, pattern SelectQuery
, From(..) , From(..)
, on , on
, from , from
@ -68,6 +70,8 @@ import Database.Esqueleto.Internal.Internal
, to3, to4, to5, to6, to7, to8 , to3, to4, to5, to6, to7, to8
, from3, from4, from5, from6, from7, from8 , from3, from4, from5, from6, from7, from8
, veryUnsafeCoerceSqlExprValue , veryUnsafeCoerceSqlExprValue
, parensM
, NeedParens(..)
) )
import GHC.TypeLits import GHC.TypeLits
@ -379,7 +383,10 @@ data SqlSetOperation a =
| UnionAll (SqlSetOperation a) (SqlSetOperation a) | UnionAll (SqlSetOperation a) (SqlSetOperation a)
| Except (SqlSetOperation a) (SqlSetOperation a) | Except (SqlSetOperation a) (SqlSetOperation a)
| Intersect (SqlSetOperation a) (SqlSetOperation a) | Intersect (SqlSetOperation a) (SqlSetOperation a)
| SelectQuery (SqlQuery a) | SelectQueryP NeedParens (SqlQuery a)
pattern SelectQuery :: SqlQuery a -> SqlSetOperation a
pattern SelectQuery q = SelectQueryP Never q
-- | Data type that represents the syntax of a 'JOIN' tree. In practice, -- | Data type that represents the syntax of a 'JOIN' tree. In practice,
-- only the @Table@ constructor is used directly when writing queries. For example, -- only the @Table@ constructor is used directly when writing queries. For example,
@ -462,7 +469,7 @@ instance {-# OVERLAPPABLE #-} ToFrom (FullOuterJoin a b) where
instance (SqlSelect a' r,SqlSelect a'' r', ToAlias a, a' ~ ToAliasT a, ToAliasReference a', ToAliasReferenceT a' ~ a'') => ToFrom (SqlSetOperation a) where instance (SqlSelect a' r,SqlSelect a'' r', ToAlias a, a' ~ ToAliasT a, ToAliasReference a', ToAliasReferenceT a' ~ a'') => ToFrom (SqlSetOperation a) where
-- If someone uses just a plain SelectQuery it should behave like a normal subquery -- If someone uses just a plain SelectQuery it should behave like a normal subquery
toFrom (SelectQuery q) = SubQuery q toFrom (SelectQueryP _ q) = SubQuery q
-- Otherwise use the SqlSetOperation -- Otherwise use the SqlSetOperation
toFrom q = SqlSetOperation q toFrom q = SqlSetOperation q
@ -617,12 +624,21 @@ from parts = do
where where
aliasQueries o = aliasQueries o =
case o of case o of
SelectQuery q -> do SelectQueryP p q -> do
(ret, sideData) <- Q $ W.censor (\_ -> mempty) $ W.listen $ unQ q (ret, sideData) <- Q $ W.censor (\_ -> mempty) $ W.listen $ unQ q
prevState <- Q $ lift S.get prevState <- Q $ lift S.get
aliasedRet <- toAlias ret aliasedRet <- toAlias ret
Q $ lift $ S.put prevState Q $ lift $ S.put prevState
pure (SelectQuery $ Q $ W.WriterT $ pure (aliasedRet, sideData), aliasedRet) let p' =
case p of
Parens -> Parens
Never ->
if (sdLimitClause sideData) /= mempty
|| length (sdOrderByClause sideData) > 0 then
Parens
else
Never
pure (SelectQueryP p' $ Q $ W.WriterT $ pure (aliasedRet, sideData), aliasedRet)
Union o1 o2 -> do Union o1 o2 -> do
(o1', ret) <- aliasQueries o1 (o1', ret) <- aliasQueries o1
(o2', _ ) <- aliasQueries o2 (o2', _ ) <- aliasQueries o2
@ -642,7 +658,9 @@ from parts = do
operationToSql o info = operationToSql o info =
case o of case o of
SelectQuery q -> toRawSql SELECT info q SelectQueryP p q ->
let (builder, values) = toRawSql SELECT info q
in (parensM p builder, values)
Union o1 o2 -> doSetOperation "UNION" info o1 o2 Union o1 o2 -> doSetOperation "UNION" info o1 o2
UnionAll o1 o2 -> doSetOperation "UNION ALL" info o1 o2 UnionAll o1 o2 -> doSetOperation "UNION ALL" info o1 o2
Except o1 o2 -> doSetOperation "EXCEPT" info o1 o2 Except o1 o2 -> doSetOperation "EXCEPT" info o1 o2

View File

@ -1886,6 +1886,7 @@ type OrderByClause = SqlExpr OrderBy
-- | A @LIMIT@ clause. -- | A @LIMIT@ clause.
data LimitClause = Limit (Maybe Int64) (Maybe Int64) data LimitClause = Limit (Maybe Int64) (Maybe Int64)
deriving Eq
instance Semigroup LimitClause where instance Semigroup LimitClause where
Limit l1 o1 <> Limit l2 o2 = Limit l1 o1 <> Limit l2 o2 =
@ -2042,6 +2043,7 @@ data SqlExpr a where
data InsertFinal data InsertFinal
data NeedParens = Parens | Never data NeedParens = Parens | Never
deriving Eq
parensM :: NeedParens -> TLB.Builder -> TLB.Builder parensM :: NeedParens -> TLB.Builder -> TLB.Builder
parensM Never = id parensM Never = id

View File

@ -2,6 +2,7 @@
, FlexibleContexts , FlexibleContexts
, RankNTypes , RankNTypes
, TypeFamilies , TypeFamilies
, TypeApplications
#-} #-}
module Main (main) where module Main (main) where
@ -17,13 +18,14 @@ import Database.Persist.MySQL ( withMySQLConn
, connectPassword , connectPassword
, defaultConnectInfo) , defaultConnectInfo)
import Database.Esqueleto import Database.Esqueleto
import Database.Esqueleto.Experimental hiding (from, on)
import qualified Database.Esqueleto.Experimental as Experimental
import qualified Control.Monad.Trans.Resource as R import qualified Control.Monad.Trans.Resource as R
import Test.Hspec import Test.Hspec
import Common.Test import Common.Test
-- testMysqlRandom :: Spec -- testMysqlRandom :: Spec
-- testMysqlRandom = do -- testMysqlRandom = do
-- -- This is known not to work until -- -- This is known not to work until
@ -162,6 +164,31 @@ testMysqlTextFunctions = do
nameContains like "iv" [p4e] nameContains like "iv" [p4e]
testMysqlUnionWithLimits :: Spec
testMysqlUnionWithLimits = do
describe "MySQL Union" $ do
it "supports limit/orderBy by parenthesizing" $ do
run $ do
mapM_ (insert . Foo) [1..6]
let q1 = do
foo <- Experimental.from $ Table @Foo
where_ $ foo ^. FooName <=. val 3
orderBy [asc $ foo ^. FooName]
limit 2
pure $ foo ^. FooName
let q2 = do
foo <- Experimental.from $ Table @Foo
where_ $ foo ^. FooName >. val 3
orderBy [asc $ foo ^. FooName]
limit 2
pure $ foo ^. FooName
ret <- select $ Experimental.from $ SelectQuery q1 `Union` SelectQuery q2
liftIO $ ret `shouldMatchList` [Value 1, Value 2, Value 4, Value 5]
main :: IO () main :: IO ()
main = do main = do
@ -180,6 +207,7 @@ main = do
testMysqlCoalesce testMysqlCoalesce
testMysqlUpdate testMysqlUpdate
testMysqlTextFunctions testMysqlTextFunctions
testMysqlUnionWithLimits