Add deleteCount and updateCount (fixes #10).

This commit is contained in:
Felipe Lessa 2013-02-12 13:09:30 -02:00
parent 53823aeb12
commit 51c08ed0e8
4 changed files with 44 additions and 16 deletions

View File

@ -55,7 +55,7 @@ library
build-depends: build-depends:
base >= 4.5 && < 4.7 base >= 4.5 && < 4.7
, text == 0.11.* , text == 0.11.*
, persistent == 1.1.* , persistent >= 1.1.5 && < 1.2
, transformers >= 0.2 , transformers >= 0.2
, unordered-containers >= 0.2 , unordered-containers >= 0.2

View File

@ -48,7 +48,9 @@ module Database.Esqueleto
, selectSource , selectSource
, selectDistinctSource , selectDistinctSource
, delete , delete
, deleteCount
, update , update
, updateCount
-- * Helpers -- * Helpers
, valkey , valkey

View File

@ -20,7 +20,9 @@ module Database.Esqueleto.Internal.Sql
, selectDistinct , selectDistinct
, selectDistinctSource , selectDistinctSource
, delete , delete
, deleteCount
, update , update
, updateCount
-- * The guts -- * The guts
, unsafeSqlBinOp , unsafeSqlBinOp
, unsafeSqlValue , unsafeSqlValue
@ -38,7 +40,7 @@ module Database.Esqueleto.Internal.Sql
import Control.Applicative (Applicative(..), (<$>), (<$)) import Control.Applicative (Applicative(..), (<$>), (<$))
import Control.Arrow ((***), first) import Control.Arrow ((***), first)
import Control.Exception (throw, throwIO) import Control.Exception (throw, throwIO)
import Control.Monad ((>=>), ap, MonadPlus(..)) import Control.Monad ((>=>), ap, void, MonadPlus(..))
import Control.Monad.IO.Class (MonadIO(..)) import Control.Monad.IO.Class (MonadIO(..))
import Control.Monad.Logger (MonadLogger) import Control.Monad.Logger (MonadLogger)
import Control.Monad.Trans.Class (lift) import Control.Monad.Trans.Class (lift)
@ -49,7 +51,7 @@ import Data.Monoid (Monoid(..), (<>))
import Database.Persist.EntityDef import Database.Persist.EntityDef
import Database.Persist.GenericSql import Database.Persist.GenericSql
import Database.Persist.GenericSql.Internal (Connection(escapeName, noLimit)) import Database.Persist.GenericSql.Internal (Connection(escapeName, noLimit))
import Database.Persist.GenericSql.Raw (execute, SqlBackend, withStmt) import Database.Persist.GenericSql.Raw (executeCount, SqlBackend, withStmt)
import Database.Persist.Store hiding (delete) import Database.Persist.Store hiding (delete)
import qualified Control.Monad.Trans.Reader as R import qualified Control.Monad.Trans.Reader as R
import qualified Control.Monad.Trans.State as S import qualified Control.Monad.Trans.State as S
@ -589,10 +591,10 @@ rawExecute :: ( MonadLogger m
, MonadResourceBase m ) , MonadResourceBase m )
=> Mode => Mode
-> SqlQuery () -> SqlQuery ()
-> SqlPersist m () -> SqlPersist m Int64
rawExecute mode query = do rawExecute mode query = do
conn <- SqlPersist R.ask conn <- SqlPersist R.ask
uncurry execute $ uncurry executeCount $
first builderToText $ first builderToText $
toRawSql mode conn query toRawSql mode conn query
@ -623,7 +625,15 @@ delete :: ( MonadLogger m
, MonadResourceBase m ) , MonadResourceBase m )
=> SqlQuery () => SqlQuery ()
-> SqlPersist m () -> SqlPersist m ()
delete = rawExecute DELETE delete = void . deleteCount
-- | Same as 'delete', but returns the number of rows affected.
deleteCount :: ( MonadLogger m
, MonadResourceBase m )
=> SqlQuery ()
-> SqlPersist m Int64
deleteCount = rawExecute DELETE
-- | Execute an @esqueleto@ @UPDATE@ query inside @persistent@'s -- | Execute an @esqueleto@ @UPDATE@ query inside @persistent@'s
@ -643,7 +653,16 @@ update :: ( MonadLogger m
, SqlEntity val ) , SqlEntity val )
=> (SqlExpr (Entity val) -> SqlQuery ()) => (SqlExpr (Entity val) -> SqlQuery ())
-> SqlPersist m () -> SqlPersist m ()
update = rawExecute UPDATE . from update = void . updateCount
-- | Same as 'update', but returns the number of rows affected.
updateCount :: ( MonadLogger m
, MonadResourceBase m
, SqlEntity val )
=> (SqlExpr (Entity val) -> SqlQuery ())
-> SqlPersist m Int64
updateCount = rawExecute UPDATE . from
---------------------------------------------------------------------- ----------------------------------------------------------------------

View File

@ -434,19 +434,22 @@ main = do
p1e <- insert' p1 p1e <- insert' p1
p2e <- insert' p2 p2e <- insert' p2
p3e <- insert' p3 p3e <- insert' p3
ret1 <- select $ let getAll = select $
from $ \p -> do from $ \p -> do
orderBy [asc (p ^. PersonName)] orderBy [asc (p ^. PersonName)]
return p return p
ret1 <- getAll
liftIO $ ret1 `shouldBe` [ p1e, p3e, p2e ] liftIO $ ret1 `shouldBe` [ p1e, p3e, p2e ]
() <- delete $ () <- delete $
from $ \p -> from $ \p ->
where_ (p ^. PersonName ==. val (personName p1)) where_ (p ^. PersonName ==. val (personName p1))
ret2 <- select $ ret2 <- getAll
from $ \p -> do
orderBy [asc (p ^. PersonName)]
return p
liftIO $ ret2 `shouldBe` [ p3e, p2e ] liftIO $ ret2 `shouldBe` [ p3e, p2e ]
n <- deleteCount $
from $ \p ->
return ((p :: SqlExpr (Entity Person)) `seq` ())
ret3 <- getAll
liftIO $ (n, ret3) `shouldBe` (2, [])
describe "update" $ do describe "update" $ do
it "works on a simple example" $ it "works on a simple example" $
@ -459,12 +462,16 @@ main = do
set p [ PersonName =. val anon set p [ PersonName =. val anon
, PersonAge *=. just (val 2) ] , PersonAge *=. just (val 2) ]
where_ (p ^. PersonName !=. val "Mike") where_ (p ^. PersonName !=. val "Mike")
n <- updateCount $ \p -> do
set p [ PersonAge +=. just (val 1) ]
where_ (p ^. PersonName !=. val "Mike")
ret <- select $ ret <- select $
from $ \p -> do from $ \p -> do
orderBy [ asc (p ^. PersonName), asc (p ^. PersonAge) ] orderBy [ asc (p ^. PersonName), asc (p ^. PersonAge) ]
return p return p
liftIO $ n `shouldBe` 2
liftIO $ ret `shouldBe` [ Entity p2k (Person anon Nothing) liftIO $ ret `shouldBe` [ Entity p2k (Person anon Nothing)
, Entity p1k (Person anon (Just 72)) , Entity p1k (Person anon (Just 73))
, Entity p3k p3 ] , Entity p3k p3 ]
it "works with a subexpression having COUNT(*)" $ it "works with a subexpression having COUNT(*)" $