diff --git a/src/Database/Esqueleto.hs b/src/Database/Esqueleto.hs index 3681b87..e21dfd6 100644 --- a/src/Database/Esqueleto.hs +++ b/src/Database/Esqueleto.hs @@ -41,8 +41,9 @@ module Database.Esqueleto Esqueleto( where_, on, groupBy, orderBy, rand, asc, desc, limit, offset , distinct, distinctOn, don, distinctOnOrderBy, having, locking , sub_select, sub_selectDistinct, (^.), (?.) - , val, isNothing, just, nothing, joinV, countRows, count, not_ - , (==.), (>=.), (>.), (<=.), (<.), (!=.), (&&.), (||.) + , val, isNothing, just, nothing, joinV + , countRows, count, countDistinct + , not_, (==.), (>=.), (>.), (<=.), (<.), (!=.), (&&.), (||.) , (+.), (-.), (/.), (*.) , random_, round_, ceiling_, floor_ , min_, max_, sum_, avg_, castNum, castNumM diff --git a/src/Database/Esqueleto/Internal/Language.hs b/src/Database/Esqueleto/Internal/Language.hs index b9a0751..fae72a8 100644 --- a/src/Database/Esqueleto/Internal/Language.hs +++ b/src/Database/Esqueleto/Internal/Language.hs @@ -327,7 +327,12 @@ class (Functor query, Applicative query, Monad query) => countRows :: Num a => expr (Value a) -- | @COUNT@. - count :: (Num a) => expr (Value typ) -> expr (Value a) + count :: Num a => expr (Value typ) -> expr (Value a) + + -- | @COUNT(DISTINCT x)@. + -- + -- /Since: 2.4.1/ + countDistinct :: Num a => expr (Value typ) -> expr (Value a) not_ :: expr (Value Bool) -> expr (Value Bool) diff --git a/src/Database/Esqueleto/Internal/Sql.hs b/src/Database/Esqueleto/Internal/Sql.hs index d32ecb1..df29f54 100644 --- a/src/Database/Esqueleto/Internal/Sql.hs +++ b/src/Database/Esqueleto/Internal/Sql.hs @@ -55,15 +55,15 @@ import Control.Exception (throw, throwIO) import Control.Monad (ap, MonadPlus(..), liftM) import Control.Monad.IO.Class (MonadIO(..)) import Control.Monad.Trans.Class (lift) -import qualified Control.Monad.Trans.Reader as R +import Control.Monad.Trans.Resource (MonadResource) +import Data.Acquire (with, allocateAcquire, Acquire) import Data.Int (Int64) import Data.List (intersperse) import Data.Monoid (Last(..), Monoid(..), (<>)) import Data.Proxy (Proxy(..)) import Database.Esqueleto.Internal.PersistentImport -import Database.Persist.Sql.Util ( - entityColumnNames, entityColumnCount, parseEntityValues, isIdField - , hasCompositeKey) +import Database.Persist.Sql.Util (entityColumnNames, entityColumnCount, parseEntityValues, isIdField, hasCompositeKey) +import qualified Control.Monad.Trans.Reader as R import qualified Control.Monad.Trans.State as S import qualified Control.Monad.Trans.Writer as W import qualified Data.Conduit as C @@ -72,8 +72,6 @@ import qualified Data.HashSet as HS import qualified Data.Text as T import qualified Data.Text.Lazy as TL import qualified Data.Text.Lazy.Builder as TLB -import Data.Acquire (with, allocateAcquire, Acquire) -import Control.Monad.Trans.Resource (MonadResource) import Database.Esqueleto.Internal.Language @@ -444,10 +442,9 @@ instance Esqueleto SqlQuery SqlExpr SqlBackend where nothing = unsafeSqlValue "NULL" joinV (ERaw p f) = ERaw p f joinV (ECompositeKey f) = ECompositeKey f - countRows = unsafeSqlValue "COUNT(*)" - count (ERaw _ f) = ERaw Never $ \info -> let (b, vals) = f info - in ("COUNT" <> parens b, vals) - count (ECompositeKey _) = unsafeSqlValue "COUNT(*)" -- Assumes no NULLs on a PK + countRows = unsafeSqlValue "COUNT(*)" + count = countHelper "" "" + countDistinct = countHelper "(DISTINCT " ")" not_ (ERaw p f) = ERaw Never $ \info -> let (b, vals) = f info in ("NOT " <> parensM p b, vals) @@ -558,6 +555,10 @@ ifNotEmptyList :: SqlExpr (ValueList a) -> Bool -> SqlExpr (Value Bool) -> SqlEx ifNotEmptyList EEmptyList b _ = val b ifNotEmptyList (EList _) _ x = x +countHelper :: Num a => TLB.Builder -> TLB.Builder -> SqlExpr (Value typ) -> SqlExpr (Value a) +countHelper open close (ERaw _ f) = ERaw Never $ first (\b -> "COUNT" <> open <> parens b <> close) . f +countHelper _ _ (ECompositeKey _) = countRows -- Assumes no NULLs on a PK + ---------------------------------------------------------------------- diff --git a/test/Test.hs b/test/Test.hs index acf2c96..aca3741 100644 --- a/test/Test.hs +++ b/test/Test.hs @@ -21,7 +21,7 @@ module Main (main) where import Control.Applicative ((<$>)) import Control.Arrow ((&&&)) import Control.Exception (IOException) -import Control.Monad (replicateM, replicateM_, void) +import Control.Monad (forM_, replicateM, replicateM_, void) import Control.Monad.IO.Class (MonadIO(liftIO)) import Control.Monad.Logger (MonadLogger(..), runStderrLoggingT, runNoLoggingT) import Control.Monad.Trans.Control (MonadBaseControl(..)) @@ -1330,6 +1330,23 @@ main = do it "looks sane for ForShare" $ sanityCheck ForShare "FOR SHARE" it "looks sane for LockInShareMode" $ sanityCheck LockInShareMode "LOCK IN SHARE MODE" + describe "counting rows" $ do + forM_ [ ("count (test A)", count . (^. PersonAge), 4) + , ("count (test B)", count . (^. PersonWeight), 5) + , ("countRows", const countRows, 5) + , ("countDistinct", countDistinct . (^. PersonAge), 2) ] $ + \(title, countKind, expected) -> + it (title ++ " works as expected") $ + run $ do + mapM_ insert + [ Person "" (Just 1) (Just 1) 1 + , Person "" (Just 2) (Just 1) 1 + , Person "" (Just 2) (Just 1) 1 + , Person "" (Just 2) (Just 2) 1 + , Person "" Nothing (Just 3) 1] + [Value n] <- select $ from $ return . countKind + liftIO $ (n :: Int) `shouldBe` expected + describe "PostgreSQL module" $ do it "should be tested on the PostgreSQL database" $ #if !defined(WITH_POSTGRESQL)