From 51c08ed0e8c9bf350914f7ac3cf7e3217affd44f Mon Sep 17 00:00:00 2001 From: Felipe Lessa Date: Tue, 12 Feb 2013 13:09:30 -0200 Subject: [PATCH] Add deleteCount and updateCount (fixes #10). --- esqueleto.cabal | 2 +- src/Database/Esqueleto.hs | 2 ++ src/Database/Esqueleto/Internal/Sql.hs | 31 +++++++++++++++++++++----- test/Test.hs | 25 +++++++++++++-------- 4 files changed, 44 insertions(+), 16 deletions(-) diff --git a/esqueleto.cabal b/esqueleto.cabal index d6521a0..c6902f2 100644 --- a/esqueleto.cabal +++ b/esqueleto.cabal @@ -55,7 +55,7 @@ library build-depends: base >= 4.5 && < 4.7 , text == 0.11.* - , persistent == 1.1.* + , persistent >= 1.1.5 && < 1.2 , transformers >= 0.2 , unordered-containers >= 0.2 diff --git a/src/Database/Esqueleto.hs b/src/Database/Esqueleto.hs index ffe3e6d..fc8b021 100644 --- a/src/Database/Esqueleto.hs +++ b/src/Database/Esqueleto.hs @@ -48,7 +48,9 @@ module Database.Esqueleto , selectSource , selectDistinctSource , delete + , deleteCount , update + , updateCount -- * Helpers , valkey diff --git a/src/Database/Esqueleto/Internal/Sql.hs b/src/Database/Esqueleto/Internal/Sql.hs index 032c9f0..28ea389 100644 --- a/src/Database/Esqueleto/Internal/Sql.hs +++ b/src/Database/Esqueleto/Internal/Sql.hs @@ -20,7 +20,9 @@ module Database.Esqueleto.Internal.Sql , selectDistinct , selectDistinctSource , delete + , deleteCount , update + , updateCount -- * The guts , unsafeSqlBinOp , unsafeSqlValue @@ -38,7 +40,7 @@ module Database.Esqueleto.Internal.Sql import Control.Applicative (Applicative(..), (<$>), (<$)) import Control.Arrow ((***), first) import Control.Exception (throw, throwIO) -import Control.Monad ((>=>), ap, MonadPlus(..)) +import Control.Monad ((>=>), ap, void, MonadPlus(..)) import Control.Monad.IO.Class (MonadIO(..)) import Control.Monad.Logger (MonadLogger) import Control.Monad.Trans.Class (lift) @@ -49,7 +51,7 @@ import Data.Monoid (Monoid(..), (<>)) import Database.Persist.EntityDef import Database.Persist.GenericSql import Database.Persist.GenericSql.Internal (Connection(escapeName, noLimit)) -import Database.Persist.GenericSql.Raw (execute, SqlBackend, withStmt) +import Database.Persist.GenericSql.Raw (executeCount, SqlBackend, withStmt) import Database.Persist.Store hiding (delete) import qualified Control.Monad.Trans.Reader as R import qualified Control.Monad.Trans.State as S @@ -589,10 +591,10 @@ rawExecute :: ( MonadLogger m , MonadResourceBase m ) => Mode -> SqlQuery () - -> SqlPersist m () + -> SqlPersist m Int64 rawExecute mode query = do conn <- SqlPersist R.ask - uncurry execute $ + uncurry executeCount $ first builderToText $ toRawSql mode conn query @@ -623,7 +625,15 @@ delete :: ( MonadLogger m , MonadResourceBase m ) => SqlQuery () -> SqlPersist m () -delete = rawExecute DELETE +delete = void . deleteCount + + +-- | Same as 'delete', but returns the number of rows affected. +deleteCount :: ( MonadLogger m + , MonadResourceBase m ) + => SqlQuery () + -> SqlPersist m Int64 +deleteCount = rawExecute DELETE -- | Execute an @esqueleto@ @UPDATE@ query inside @persistent@'s @@ -643,7 +653,16 @@ update :: ( MonadLogger m , SqlEntity val ) => (SqlExpr (Entity val) -> SqlQuery ()) -> SqlPersist m () -update = rawExecute UPDATE . from +update = void . updateCount + + +-- | Same as 'update', but returns the number of rows affected. +updateCount :: ( MonadLogger m + , MonadResourceBase m + , SqlEntity val ) + => (SqlExpr (Entity val) -> SqlQuery ()) + -> SqlPersist m Int64 +updateCount = rawExecute UPDATE . from ---------------------------------------------------------------------- diff --git a/test/Test.hs b/test/Test.hs index 414bf92..532b434 100644 --- a/test/Test.hs +++ b/test/Test.hs @@ -434,19 +434,22 @@ main = do p1e <- insert' p1 p2e <- insert' p2 p3e <- insert' p3 - ret1 <- select $ - from $ \p -> do - orderBy [asc (p ^. PersonName)] - return p + let getAll = select $ + from $ \p -> do + orderBy [asc (p ^. PersonName)] + return p + ret1 <- getAll liftIO $ ret1 `shouldBe` [ p1e, p3e, p2e ] () <- delete $ from $ \p -> where_ (p ^. PersonName ==. val (personName p1)) - ret2 <- select $ - from $ \p -> do - orderBy [asc (p ^. PersonName)] - return p + ret2 <- getAll liftIO $ ret2 `shouldBe` [ p3e, p2e ] + n <- deleteCount $ + from $ \p -> + return ((p :: SqlExpr (Entity Person)) `seq` ()) + ret3 <- getAll + liftIO $ (n, ret3) `shouldBe` (2, []) describe "update" $ do it "works on a simple example" $ @@ -459,12 +462,16 @@ main = do set p [ PersonName =. val anon , PersonAge *=. just (val 2) ] where_ (p ^. PersonName !=. val "Mike") + n <- updateCount $ \p -> do + set p [ PersonAge +=. just (val 1) ] + where_ (p ^. PersonName !=. val "Mike") ret <- select $ from $ \p -> do orderBy [ asc (p ^. PersonName), asc (p ^. PersonAge) ] return p + liftIO $ n `shouldBe` 2 liftIO $ ret `shouldBe` [ Entity p2k (Person anon Nothing) - , Entity p1k (Person anon (Just 72)) + , Entity p1k (Person anon (Just 73)) , Entity p3k p3 ] it "works with a subexpression having COUNT(*)" $