diff --git a/src/Database/Esqueleto.hs b/src/Database/Esqueleto.hs index e08d92a..1545adc 100644 --- a/src/Database/Esqueleto.hs +++ b/src/Database/Esqueleto.hs @@ -18,7 +18,7 @@ module Database.Esqueleto -- $gettingstarted -- * @esqueleto@'s Language - Esqueleto( where_, on, orderBy, asc, desc + Esqueleto( where_, on, orderBy, asc, desc, limit, offset , sub_select, sub_selectDistinct, (^.), (?.) , val, isNothing, just, nothing, countRows, not_ , (==.), (>=.), (>.), (<=.), (<.), (!=.), (&&.), (||.) diff --git a/src/Database/Esqueleto/Internal/Language.hs b/src/Database/Esqueleto/Internal/Language.hs index 89738bb..de42919 100644 --- a/src/Database/Esqueleto/Internal/Language.hs +++ b/src/Database/Esqueleto/Internal/Language.hs @@ -33,6 +33,7 @@ module Database.Esqueleto.Internal.Language import Control.Applicative (Applicative(..), (<$>)) import Control.Exception (Exception) +import Data.Int (Int64) import Data.String (IsString) import Data.Typeable (Typeable) import Database.Persist.GenericSql @@ -136,6 +137,12 @@ class (Functor query, Applicative query, Monad query) => -- | Descending order of this field or expression. desc :: PersistField a => expr (Value a) -> expr OrderBy + -- | @LIMIT@. Limit the number of returned rows. + limit :: Int64 -> query () + + -- | @OFFSET@. Usually used with 'limit'. + offset :: Int64 -> query () + -- | Execute a subquery @SELECT@ in an expression. sub_select :: PersistField a => query (expr (Value a)) -> expr (Value a) diff --git a/src/Database/Esqueleto/Internal/Sql.hs b/src/Database/Esqueleto/Internal/Sql.hs index d8ec2a5..dd8433a 100644 --- a/src/Database/Esqueleto/Internal/Sql.hs +++ b/src/Database/Esqueleto/Internal/Sql.hs @@ -31,7 +31,6 @@ module Database.Esqueleto.Internal.Sql , rawExecute , toRawSql , Mode(..) - , Escape , SqlSelect , veryUnsafeCoerceSqlExprValue ) where @@ -44,11 +43,12 @@ import Control.Monad.IO.Class (MonadIO(..)) import Control.Monad.Logger (MonadLogger) import Control.Monad.Trans.Class (lift) import Control.Monad.Trans.Resource (MonadResourceBase) +import Data.Int (Int64) import Data.List (intersperse) import Data.Monoid (Monoid(..), (<>)) import Database.Persist.EntityDef import Database.Persist.GenericSql -import Database.Persist.GenericSql.Internal (Connection(escapeName)) +import Database.Persist.GenericSql.Internal (Connection(escapeName, noLimit)) import Database.Persist.GenericSql.Raw (withStmt, execute) import Database.Persist.Store hiding (delete) import qualified Control.Monad.Trans.Reader as R @@ -60,6 +60,7 @@ import qualified Data.HashSet as HS import qualified Data.Text as T import qualified Data.Text.Lazy as TL import qualified Data.Text.Lazy.Builder as TLB +import qualified Data.Text.Lazy.Builder.Int as TLBI import Database.Esqueleto.Internal.Language @@ -93,12 +94,13 @@ data SideData = SideData { sdFromClause :: ![FromClause] , sdSetClause :: ![SetClause] , sdWhereClause :: !WhereClause , sdOrderByClause :: ![OrderByClause] + , sdLimitClause :: !LimitClause } instance Monoid SideData where - mempty = SideData mempty mempty mempty mempty - SideData f s w o `mappend` SideData f' s' w' o' = - SideData (f <> f') (s <> s') (w <> w') (o <> o') + mempty = SideData mempty mempty mempty mempty mempty + SideData f s w o l `mappend` SideData f' s' w' o' l' = + SideData (f <> f') (s <> s') (w <> w') (o <> o') (l <> l') -- | A part of a @FROM@ clause. @@ -155,6 +157,18 @@ instance Monoid WhereClause where type OrderByClause = SqlExpr OrderBy +-- | A @LIMIT@ clause. +data LimitClause = Limit (Maybe Int64) (Maybe Int64) + +instance Monoid LimitClause where + mempty = Limit mzero mzero + Limit l1 o1 `mappend` Limit l2 o2 = + Limit (l2 `mplus` l1) (o2 `mplus` o1) + -- More than one 'limit' or 'offset' is issued, we want to + -- keep the latest one. That's why we use mplus with + -- "reversed" arguments. + + ---------------------------------------------------------------------- @@ -193,8 +207,8 @@ newIdentFor = Q . lift . try . unDBName -- | Use an identifier. -useIdent :: Escape -> Ident -> TLB.Builder -useIdent esc (I ident) = esc (DBName ident) +useIdent :: Connection -> Ident -> TLB.Builder +useIdent conn (I ident) = fromDBName conn $ DBName ident ---------------------------------------------------------------------- @@ -204,7 +218,7 @@ useIdent esc (I ident) = esc (DBName ident) data SqlExpr a where EEntity :: Ident -> SqlExpr (Entity val) EMaybe :: SqlExpr a -> SqlExpr (Maybe a) - ERaw :: NeedParens -> (Escape -> (TLB.Builder, [PersistValue])) -> SqlExpr (Value a) + ERaw :: NeedParens -> (Connection -> (TLB.Builder, [PersistValue])) -> SqlExpr (Value a) EOrderBy :: OrderByType -> SqlExpr (Value a) -> SqlExpr OrderBy ESet :: (SqlExpr (Entity val) -> SqlExpr (Value ())) -> SqlExpr (Update val) EPreprocessedFrom :: a -> FromClause -> SqlExpr (PreprocessedFrom a) @@ -217,9 +231,6 @@ parensM Parens = parens data OrderByType = ASC | DESC --- | (Internal) Backend-specific function that escapes a 'DBName'. -type Escape = DBName -> TLB.Builder - instance Esqueleto SqlQuery SqlExpr SqlPersist where fromStart = x @@ -260,11 +271,14 @@ instance Esqueleto SqlQuery SqlExpr SqlPersist where asc = EOrderBy ASC desc = EOrderBy DESC + limit n = Q $ W.tell mempty { sdLimitClause = Limit (Just n) Nothing } + offset n = Q $ W.tell mempty { sdLimitClause = Limit Nothing (Just n) } + sub_select = sub SELECT sub_selectDistinct = sub SELECT_DISTINCT EEntity ident ^. field = - ERaw Never $ \esc -> (useIdent esc ident <> ("." <> fieldName esc field), []) + ERaw Never $ \conn -> (useIdent conn ident <> ("." <> fieldName conn field), []) EMaybe r ?. field = maybelize (r ^. field) where @@ -278,8 +292,8 @@ instance Esqueleto SqlQuery SqlExpr SqlPersist where nothing = unsafeSqlValue "NULL" countRows = unsafeSqlValue "COUNT(*)" - not_ (ERaw p f) = ERaw Never $ \esc -> let (b, vals) = f esc - in ("NOT " <> parensM p b, vals) + not_ (ERaw p f) = ERaw Never $ \conn -> let (b, vals) = f conn + in ("NOT " <> parensM p b, vals) (==.) = unsafeSqlBinOp " = " (>=.) = unsafeSqlBinOp " >= " @@ -311,18 +325,18 @@ instance Esqueleto SqlQuery SqlExpr SqlPersist where fieldName :: (PersistEntity val, PersistField typ) - => Escape -> EntityField val typ -> TLB.Builder -fieldName esc = esc . fieldDB . persistFieldDef + => Connection -> EntityField val typ -> TLB.Builder +fieldName conn = fromDBName conn . fieldDB . persistFieldDef setAux :: (PersistEntity val, PersistField typ) => EntityField val typ -> (SqlExpr (Entity val) -> SqlExpr (Value typ)) -> SqlExpr (Update val) setAux field mkVal = ESet $ \ent -> unsafeSqlBinOp " = " name (mkVal ent) - where name = ERaw Never $ \esc -> (fieldName esc field, mempty) + where name = ERaw Never $ \conn -> (fieldName conn field, mempty) sub :: PersistField a => Mode -> SqlQuery (SqlExpr (Value a)) -> SqlExpr (Value a) -sub mode query = ERaw Parens $ \esc -> first parens (toRawSql mode esc query) +sub mode query = ERaw Parens $ \conn -> first parens (toRawSql mode conn query) fromDBName :: Connection -> DBName -> TLB.Builder fromDBName conn = TLB.fromText . escapeName conn @@ -346,10 +360,10 @@ fromDBName conn = TLB.fromText . escapeName conn unsafeSqlBinOp :: TLB.Builder -> SqlExpr (Value a) -> SqlExpr (Value b) -> SqlExpr (Value c) unsafeSqlBinOp op (ERaw p1 f1) (ERaw p2 f2) = ERaw Parens f where - f esc = let (b1, vals1) = f1 esc - (b2, vals2) = f2 esc - in ( parensM p1 b1 <> op <> parensM p2 b2 - , vals1 <> vals2 ) + f conn = let (b1, vals1) = f1 conn + (b2, vals2) = f2 conn + in ( parensM p1 b1 <> op <> parensM p2 b2 + , vals1 <> vals2 ) {-# INLINE unsafeSqlBinOp #-} @@ -365,9 +379,9 @@ unsafeSqlValue v = ERaw Never $ \_ -> (v, mempty) unsafeSqlFunction :: UnsafeSqlFunctionArgument a => TLB.Builder -> a -> SqlExpr (Value b) unsafeSqlFunction name arg = - ERaw Never $ \esc -> + ERaw Never $ \conn -> let (argsTLB, argsVals) = - uncommas' $ map (\(ERaw _ f) -> f esc) $ toArgList arg + uncommas' $ map (\(ERaw _ f) -> f conn) $ toArgList arg in (name <> parens argsTLB, argsVals) class UnsafeSqlFunctionArgument a where @@ -421,7 +435,7 @@ rawSelectSource mode query = src run conn = uncurry withStmt $ first builderToText $ - toRawSql mode (fromDBName conn) query + toRawSql mode conn query massage = do mrow <- C.await @@ -533,7 +547,7 @@ rawExecute mode query = do conn <- SqlPersist R.ask uncurry execute $ first builderToText $ - toRawSql mode (fromDBName conn) query + toRawSql mode conn query -- | Execute an @esqueleto@ @DELETE@ query inside @persistent@'s @@ -600,18 +614,19 @@ builderToText = TL.toStrict . TLB.toLazyTextWith defaultChunkSize -- @esqueleto@, instead of manually using this function (which is -- possible but tedious), you may just turn on query logging of -- @persistent@. -toRawSql :: SqlSelect a r => Mode -> Escape -> SqlQuery a -> (TLB.Builder, [PersistValue]) -toRawSql mode esc query = - let (ret, SideData fromClauses setClauses whereClauses orderByClauses) = +toRawSql :: SqlSelect a r => Mode -> Connection -> SqlQuery a -> (TLB.Builder, [PersistValue]) +toRawSql mode conn query = + let (ret, SideData fromClauses setClauses whereClauses orderByClauses limitClause) = flip S.evalState initialIdentState $ W.runWriterT $ unQ query in mconcat - [ makeSelect esc mode ret - , makeFrom esc mode fromClauses - , makeSet esc setClauses - , makeWhere esc whereClauses - , makeOrderBy esc orderByClauses + [ makeSelect conn mode ret + , makeFrom conn mode fromClauses + , makeSet conn setClauses + , makeWhere conn whereClauses + , makeOrderBy conn orderByClauses + , makeLimit conn limitClause ] -- | (Internal) Mode of query being converted by 'toRawSql'. @@ -625,8 +640,8 @@ uncommas' :: Monoid a => [(TLB.Builder, a)] -> (TLB.Builder, a) uncommas' = (uncommas *** mconcat) . unzip -makeSelect :: SqlSelect a r => Escape -> Mode -> a -> (TLB.Builder, [PersistValue]) -makeSelect esc mode ret = first (s <>) (sqlSelectCols esc ret) +makeSelect :: SqlSelect a r => Connection -> Mode -> a -> (TLB.Builder, [PersistValue]) +makeSelect conn mode ret = first (s <>) (sqlSelectCols conn ret) where s = case mode of SELECT -> "SELECT " @@ -635,9 +650,9 @@ makeSelect esc mode ret = first (s <>) (sqlSelectCols esc ret) UPDATE -> "UPDATE " -makeFrom :: Escape -> Mode -> [FromClause] -> (TLB.Builder, [PersistValue]) -makeFrom _ _ [] = mempty -makeFrom esc mode fs = ret +makeFrom :: Connection -> Mode -> [FromClause] -> (TLB.Builder, [PersistValue]) +makeFrom _ _ [] = mempty +makeFrom conn mode fs = ret where ret = case collectOnClauses fs of Left expr -> throw $ mkExc expr @@ -658,8 +673,8 @@ makeFrom esc mode fs = ret base ident@(I identText) def = let db@(DBName dbText) = entityDB def in ( if dbText == identText - then esc db - else esc db <> (" AS " <> useIdent esc ident) + then fromDBName conn db + else fromDBName conn db <> (" AS " <> useIdent conn ident) , mempty ) fromKind InnerJoinKind = " INNER JOIN " @@ -668,35 +683,53 @@ makeFrom esc mode fs = ret fromKind RightOuterJoinKind = " RIGHT OUTER JOIN " fromKind FullOuterJoinKind = " FULL OUTER JOIN " - makeOnClause (ERaw _ f) = first (" ON " <>) (f esc) + makeOnClause (ERaw _ f) = first (" ON " <>) (f conn) mkExc :: SqlExpr (Value Bool) -> OnClauseWithoutMatchingJoinException mkExc (ERaw _ f) = OnClauseWithoutMatchingJoinException $ - TL.unpack $ TLB.toLazyText $ fst (f esc) + TL.unpack $ TLB.toLazyText $ fst (f conn) -makeSet :: Escape -> [SetClause] -> (TLB.Builder, [PersistValue]) -makeSet _ [] = mempty -makeSet esc os = first ("\nSET " <>) $ uncommas' (map mk os) +makeSet :: Connection -> [SetClause] -> (TLB.Builder, [PersistValue]) +makeSet _ [] = mempty +makeSet conn os = first ("\nSET " <>) $ uncommas' (map mk os) where - mk (SetClause (ERaw _ f)) = f esc + mk (SetClause (ERaw _ f)) = f conn -makeWhere :: Escape -> WhereClause -> (TLB.Builder, [PersistValue]) -makeWhere _ NoWhere = mempty -makeWhere esc (Where (ERaw _ f)) = first ("\nWHERE " <>) (f esc) +makeWhere :: Connection -> WhereClause -> (TLB.Builder, [PersistValue]) +makeWhere _ NoWhere = mempty +makeWhere conn (Where (ERaw _ f)) = first ("\nWHERE " <>) (f conn) -makeOrderBy :: Escape -> [OrderByClause] -> (TLB.Builder, [PersistValue]) -makeOrderBy _ [] = mempty -makeOrderBy esc os = first ("\nORDER BY " <>) $ uncommas' (map mk os) +makeOrderBy :: Connection -> [OrderByClause] -> (TLB.Builder, [PersistValue]) +makeOrderBy _ [] = mempty +makeOrderBy conn os = first ("\nORDER BY " <>) $ uncommas' (map mk os) where - mk (EOrderBy t (ERaw _ f)) = first (<> orderByType t) (f esc) + mk (EOrderBy t (ERaw _ f)) = first (<> orderByType t) (f conn) orderByType ASC = " ASC" orderByType DESC = " DESC" +makeLimit :: Connection -> LimitClause -> (TLB.Builder, [PersistValue]) +makeLimit _ (Limit Nothing Nothing) = mempty +makeLimit _ (Limit Nothing (Just 0)) = mempty +makeLimit conn (Limit ml mo) = (ret, mempty) + where + ret = TLB.singleton '\n' <> (limitTLB <> offsetTLB) + + limitTLB = + case ml of + Just l -> "LIMIT " <> TLBI.decimal l + Nothing -> TLB.fromText (noLimit conn) + + offsetTLB = + case mo of + Just o -> " OFFSET " <> TLBI.decimal o + Nothing -> mempty + + parens :: TLB.Builder -> TLB.Builder parens b = "(" <> (b <> ")") @@ -714,7 +747,7 @@ class SqlSelect a r | a -> r, r -> a where -- | Creates the variable part of the @SELECT@ query and -- returns the list of 'PersistValue's that will be given to -- 'withStmt'. - sqlSelectCols :: Escape -> a -> (TLB.Builder, [PersistValue]) + sqlSelectCols :: Connection -> a -> (TLB.Builder, [PersistValue]) -- | Number of columns that will be consumed. Must be -- non-strict on the argument. @@ -733,10 +766,10 @@ instance SqlSelect () () where -- | You may return an 'Entity' from a 'select' query. instance PersistEntity a => SqlSelect (SqlExpr (Entity a)) (Entity a) where - sqlSelectCols escape expr@(EEntity ident) = ret + sqlSelectCols conn expr@(EEntity ident) = ret where process ed = uncommas $ - map ((name <>) . escape) $ + map ((name <>) . fromDBName conn) $ (entityID ed:) $ map fieldDB $ entityFields ed @@ -746,7 +779,7 @@ instance PersistEntity a => SqlSelect (SqlExpr (Entity a)) (Entity a) where -- clause), while 'rawSql' assumes that it's just the -- name of the table (which doesn't allow self-joins, for -- example). - name = useIdent escape ident <> "." + name = useIdent conn ident <> "." ret = let ed = entityDef $ getEntityVal expr in (process ed, mempty) sqlSelectColCount = (+1) . length . entityFields . entityDef . getEntityVal @@ -761,7 +794,7 @@ getEntityVal = error "Esqueleto/Sql/getEntityVal" -- | You may return a possibly-@NULL@ 'Entity' from a 'select' query. instance PersistEntity a => SqlSelect (SqlExpr (Maybe (Entity a))) (Maybe (Entity a)) where - sqlSelectCols escape (EMaybe ent) = sqlSelectCols escape ent + sqlSelectCols conn (EMaybe ent) = sqlSelectCols conn ent sqlSelectColCount = sqlSelectColCount . fromEMaybe where fromEMaybe :: SqlExpr (Maybe e) -> SqlExpr e diff --git a/test/Test.hs b/test/Test.hs index 066ae76..ba1f90f 100644 --- a/test/Test.hs +++ b/test/Test.hs @@ -117,6 +117,50 @@ main = do , (Value (personName p2), Value (personName p1)) , (Value (personName p2), Value (personName p2)) ] + it "works with many kinds of LIMITs and OFFSETs" $ + run $ do + [p1e, p2e, p3e, p4e] <- mapM insert' [p1, p2, p3, p4] + let people = from $ \p -> do + orderBy [asc (p ^. PersonName)] + return p + ret1 <- select $ do + p <- people + limit 2 + limit 1 + return p + liftIO $ ret1 `shouldBe` [ p1e ] + ret2 <- select $ do + p <- people + limit 1 + limit 2 + return p + liftIO $ ret2 `shouldBe` [ p1e, p4e ] + ret3 <- select $ do + p <- people + offset 3 + offset 2 + return p + liftIO $ ret3 `shouldBe` [ p3e, p2e ] + ret4 <- select $ do + p <- people + offset 3 + limit 5 + offset 2 + limit 3 + offset 1 + limit 2 + return p + liftIO $ ret4 `shouldBe` [ p4e, p3e ] + ret5 <- select $ do + p <- people + offset 1000 + limit 1 + limit 1000 + offset 0 + return p + liftIO $ ret5 `shouldBe` [ p1e, p4e, p3e, p2e ] + + describe "select/JOIN" $ do it "works with a LEFT OUTER JOIN" $ run $ do