From 2b5b561f6e782cf9f1f122fcde1eb4db2e6e21e9 Mon Sep 17 00:00:00 2001 From: Ben Levy Date: Sun, 30 Aug 2020 14:15:11 -0500 Subject: [PATCH] Add new SetOperation constructor for parenthesized query (#195) * Add new SetOperation constructor for parenthesized query. Automatically detect when parentheses are needed on SelectQuery usage (only works for MySQL). * Add Parens to SelectQueryP and create a pattern synonym for SelectQuery. SelectQueryP is hidden as end users should only be using SelectQuery. --- src/Database/Esqueleto/Experimental.hs | 30 ++++++++++++++++----- src/Database/Esqueleto/Internal/Internal.hs | 2 ++ test/MySQL/Test.hs | 30 ++++++++++++++++++++- 3 files changed, 55 insertions(+), 7 deletions(-) diff --git a/src/Database/Esqueleto/Experimental.hs b/src/Database/Esqueleto/Experimental.hs index bdec319..f431a4b 100644 --- a/src/Database/Esqueleto/Experimental.hs +++ b/src/Database/Esqueleto/Experimental.hs @@ -9,6 +9,7 @@ , TypeFamilies , UndecidableInstances , OverloadedStrings + , PatternSynonyms #-} module Database.Esqueleto.Experimental @@ -23,7 +24,8 @@ module Database.Esqueleto.Experimental -- * Documentation - SqlSetOperation(..) + SqlSetOperation(Union, UnionAll, Except, Intersect) + , pattern SelectQuery , From(..) , on , from @@ -68,6 +70,8 @@ import Database.Esqueleto.Internal.Internal , to3, to4, to5, to6, to7, to8 , from3, from4, from5, from6, from7, from8 , veryUnsafeCoerceSqlExprValue + , parensM + , NeedParens(..) ) import GHC.TypeLits @@ -379,7 +383,10 @@ data SqlSetOperation a = | UnionAll (SqlSetOperation a) (SqlSetOperation a) | Except (SqlSetOperation a) (SqlSetOperation a) | Intersect (SqlSetOperation a) (SqlSetOperation a) - | SelectQuery (SqlQuery a) + | SelectQueryP NeedParens (SqlQuery a) + +pattern SelectQuery :: SqlQuery a -> SqlSetOperation a +pattern SelectQuery q = SelectQueryP Never q -- | Data type that represents the syntax of a 'JOIN' tree. In practice, -- only the @Table@ constructor is used directly when writing queries. For example, @@ -462,7 +469,7 @@ instance {-# OVERLAPPABLE #-} ToFrom (FullOuterJoin a b) where instance (SqlSelect a' r,SqlSelect a'' r', ToAlias a, a' ~ ToAliasT a, ToAliasReference a', ToAliasReferenceT a' ~ a'') => ToFrom (SqlSetOperation a) where -- If someone uses just a plain SelectQuery it should behave like a normal subquery - toFrom (SelectQuery q) = SubQuery q + toFrom (SelectQueryP _ q) = SubQuery q -- Otherwise use the SqlSetOperation toFrom q = SqlSetOperation q @@ -617,12 +624,21 @@ from parts = do where aliasQueries o = case o of - SelectQuery q -> do + SelectQueryP p q -> do (ret, sideData) <- Q $ W.censor (\_ -> mempty) $ W.listen $ unQ q prevState <- Q $ lift S.get aliasedRet <- toAlias ret Q $ lift $ S.put prevState - pure (SelectQuery $ Q $ W.WriterT $ pure (aliasedRet, sideData), aliasedRet) + let p' = + case p of + Parens -> Parens + Never -> + if (sdLimitClause sideData) /= mempty + || length (sdOrderByClause sideData) > 0 then + Parens + else + Never + pure (SelectQueryP p' $ Q $ W.WriterT $ pure (aliasedRet, sideData), aliasedRet) Union o1 o2 -> do (o1', ret) <- aliasQueries o1 (o2', _ ) <- aliasQueries o2 @@ -642,7 +658,9 @@ from parts = do operationToSql o info = case o of - SelectQuery q -> toRawSql SELECT info q + SelectQueryP p q -> + let (builder, values) = toRawSql SELECT info q + in (parensM p builder, values) Union o1 o2 -> doSetOperation "UNION" info o1 o2 UnionAll o1 o2 -> doSetOperation "UNION ALL" info o1 o2 Except o1 o2 -> doSetOperation "EXCEPT" info o1 o2 diff --git a/src/Database/Esqueleto/Internal/Internal.hs b/src/Database/Esqueleto/Internal/Internal.hs index 5e89f68..1722ba1 100644 --- a/src/Database/Esqueleto/Internal/Internal.hs +++ b/src/Database/Esqueleto/Internal/Internal.hs @@ -1886,6 +1886,7 @@ type OrderByClause = SqlExpr OrderBy -- | A @LIMIT@ clause. data LimitClause = Limit (Maybe Int64) (Maybe Int64) + deriving Eq instance Semigroup LimitClause where Limit l1 o1 <> Limit l2 o2 = @@ -2042,6 +2043,7 @@ data SqlExpr a where data InsertFinal data NeedParens = Parens | Never + deriving Eq parensM :: NeedParens -> TLB.Builder -> TLB.Builder parensM Never = id diff --git a/test/MySQL/Test.hs b/test/MySQL/Test.hs index 3cfd89d..0350845 100644 --- a/test/MySQL/Test.hs +++ b/test/MySQL/Test.hs @@ -2,6 +2,7 @@ , FlexibleContexts , RankNTypes , TypeFamilies + , TypeApplications #-} module Main (main) where @@ -17,13 +18,14 @@ import Database.Persist.MySQL ( withMySQLConn , connectPassword , defaultConnectInfo) import Database.Esqueleto +import Database.Esqueleto.Experimental hiding (from, on) +import qualified Database.Esqueleto.Experimental as Experimental import qualified Control.Monad.Trans.Resource as R import Test.Hspec import Common.Test - -- testMysqlRandom :: Spec -- testMysqlRandom = do -- -- This is known not to work until @@ -162,6 +164,31 @@ testMysqlTextFunctions = do nameContains like "iv" [p4e] +testMysqlUnionWithLimits :: Spec +testMysqlUnionWithLimits = do + describe "MySQL Union" $ do + it "supports limit/orderBy by parenthesizing" $ do + run $ do + mapM_ (insert . Foo) [1..6] + + let q1 = do + foo <- Experimental.from $ Table @Foo + where_ $ foo ^. FooName <=. val 3 + orderBy [asc $ foo ^. FooName] + limit 2 + pure $ foo ^. FooName + + let q2 = do + foo <- Experimental.from $ Table @Foo + where_ $ foo ^. FooName >. val 3 + orderBy [asc $ foo ^. FooName] + limit 2 + pure $ foo ^. FooName + + + ret <- select $ Experimental.from $ SelectQuery q1 `Union` SelectQuery q2 + liftIO $ ret `shouldMatchList` [Value 1, Value 2, Value 4, Value 5] + main :: IO () main = do @@ -180,6 +207,7 @@ main = do testMysqlCoalesce testMysqlUpdate testMysqlTextFunctions + testMysqlUnionWithLimits