Add deleteCount and updateCount (fixes #10).
This commit is contained in:
parent
53823aeb12
commit
51c08ed0e8
@ -55,7 +55,7 @@ library
|
||||
build-depends:
|
||||
base >= 4.5 && < 4.7
|
||||
, text == 0.11.*
|
||||
, persistent == 1.1.*
|
||||
, persistent >= 1.1.5 && < 1.2
|
||||
, transformers >= 0.2
|
||||
, unordered-containers >= 0.2
|
||||
|
||||
|
||||
@ -48,7 +48,9 @@ module Database.Esqueleto
|
||||
, selectSource
|
||||
, selectDistinctSource
|
||||
, delete
|
||||
, deleteCount
|
||||
, update
|
||||
, updateCount
|
||||
|
||||
-- * Helpers
|
||||
, valkey
|
||||
|
||||
@ -20,7 +20,9 @@ module Database.Esqueleto.Internal.Sql
|
||||
, selectDistinct
|
||||
, selectDistinctSource
|
||||
, delete
|
||||
, deleteCount
|
||||
, update
|
||||
, updateCount
|
||||
-- * The guts
|
||||
, unsafeSqlBinOp
|
||||
, unsafeSqlValue
|
||||
@ -38,7 +40,7 @@ module Database.Esqueleto.Internal.Sql
|
||||
import Control.Applicative (Applicative(..), (<$>), (<$))
|
||||
import Control.Arrow ((***), first)
|
||||
import Control.Exception (throw, throwIO)
|
||||
import Control.Monad ((>=>), ap, MonadPlus(..))
|
||||
import Control.Monad ((>=>), ap, void, MonadPlus(..))
|
||||
import Control.Monad.IO.Class (MonadIO(..))
|
||||
import Control.Monad.Logger (MonadLogger)
|
||||
import Control.Monad.Trans.Class (lift)
|
||||
@ -49,7 +51,7 @@ import Data.Monoid (Monoid(..), (<>))
|
||||
import Database.Persist.EntityDef
|
||||
import Database.Persist.GenericSql
|
||||
import Database.Persist.GenericSql.Internal (Connection(escapeName, noLimit))
|
||||
import Database.Persist.GenericSql.Raw (execute, SqlBackend, withStmt)
|
||||
import Database.Persist.GenericSql.Raw (executeCount, SqlBackend, withStmt)
|
||||
import Database.Persist.Store hiding (delete)
|
||||
import qualified Control.Monad.Trans.Reader as R
|
||||
import qualified Control.Monad.Trans.State as S
|
||||
@ -589,10 +591,10 @@ rawExecute :: ( MonadLogger m
|
||||
, MonadResourceBase m )
|
||||
=> Mode
|
||||
-> SqlQuery ()
|
||||
-> SqlPersist m ()
|
||||
-> SqlPersist m Int64
|
||||
rawExecute mode query = do
|
||||
conn <- SqlPersist R.ask
|
||||
uncurry execute $
|
||||
uncurry executeCount $
|
||||
first builderToText $
|
||||
toRawSql mode conn query
|
||||
|
||||
@ -623,7 +625,15 @@ delete :: ( MonadLogger m
|
||||
, MonadResourceBase m )
|
||||
=> SqlQuery ()
|
||||
-> SqlPersist m ()
|
||||
delete = rawExecute DELETE
|
||||
delete = void . deleteCount
|
||||
|
||||
|
||||
-- | Same as 'delete', but returns the number of rows affected.
|
||||
deleteCount :: ( MonadLogger m
|
||||
, MonadResourceBase m )
|
||||
=> SqlQuery ()
|
||||
-> SqlPersist m Int64
|
||||
deleteCount = rawExecute DELETE
|
||||
|
||||
|
||||
-- | Execute an @esqueleto@ @UPDATE@ query inside @persistent@'s
|
||||
@ -643,7 +653,16 @@ update :: ( MonadLogger m
|
||||
, SqlEntity val )
|
||||
=> (SqlExpr (Entity val) -> SqlQuery ())
|
||||
-> SqlPersist m ()
|
||||
update = rawExecute UPDATE . from
|
||||
update = void . updateCount
|
||||
|
||||
|
||||
-- | Same as 'update', but returns the number of rows affected.
|
||||
updateCount :: ( MonadLogger m
|
||||
, MonadResourceBase m
|
||||
, SqlEntity val )
|
||||
=> (SqlExpr (Entity val) -> SqlQuery ())
|
||||
-> SqlPersist m Int64
|
||||
updateCount = rawExecute UPDATE . from
|
||||
|
||||
|
||||
----------------------------------------------------------------------
|
||||
|
||||
25
test/Test.hs
25
test/Test.hs
@ -434,19 +434,22 @@ main = do
|
||||
p1e <- insert' p1
|
||||
p2e <- insert' p2
|
||||
p3e <- insert' p3
|
||||
ret1 <- select $
|
||||
from $ \p -> do
|
||||
orderBy [asc (p ^. PersonName)]
|
||||
return p
|
||||
let getAll = select $
|
||||
from $ \p -> do
|
||||
orderBy [asc (p ^. PersonName)]
|
||||
return p
|
||||
ret1 <- getAll
|
||||
liftIO $ ret1 `shouldBe` [ p1e, p3e, p2e ]
|
||||
() <- delete $
|
||||
from $ \p ->
|
||||
where_ (p ^. PersonName ==. val (personName p1))
|
||||
ret2 <- select $
|
||||
from $ \p -> do
|
||||
orderBy [asc (p ^. PersonName)]
|
||||
return p
|
||||
ret2 <- getAll
|
||||
liftIO $ ret2 `shouldBe` [ p3e, p2e ]
|
||||
n <- deleteCount $
|
||||
from $ \p ->
|
||||
return ((p :: SqlExpr (Entity Person)) `seq` ())
|
||||
ret3 <- getAll
|
||||
liftIO $ (n, ret3) `shouldBe` (2, [])
|
||||
|
||||
describe "update" $ do
|
||||
it "works on a simple example" $
|
||||
@ -459,12 +462,16 @@ main = do
|
||||
set p [ PersonName =. val anon
|
||||
, PersonAge *=. just (val 2) ]
|
||||
where_ (p ^. PersonName !=. val "Mike")
|
||||
n <- updateCount $ \p -> do
|
||||
set p [ PersonAge +=. just (val 1) ]
|
||||
where_ (p ^. PersonName !=. val "Mike")
|
||||
ret <- select $
|
||||
from $ \p -> do
|
||||
orderBy [ asc (p ^. PersonName), asc (p ^. PersonAge) ]
|
||||
return p
|
||||
liftIO $ n `shouldBe` 2
|
||||
liftIO $ ret `shouldBe` [ Entity p2k (Person anon Nothing)
|
||||
, Entity p1k (Person anon (Just 72))
|
||||
, Entity p1k (Person anon (Just 73))
|
||||
, Entity p3k p3 ]
|
||||
|
||||
it "works with a subexpression having COUNT(*)" $
|
||||
|
||||
Loading…
Reference in New Issue
Block a user