Insert Select With Conflict for postgres (#155)

* add insertSelectWithConflict to allow insert with conflict resolution

* insertSelectWithConflictCount does nothing when no updates given and add tests

* no longer require undefined for insertSelectWithConflict

* Update src/Database/Esqueleto/PostgreSQL.hs

Co-Authored-By: Matt Parsons <parsonsmatt@gmail.com>

* Update src/Database/Esqueleto/PostgreSQL.hs

Co-Authored-By: Matt Parsons <parsonsmatt@gmail.com>

* Update src/Database/Esqueleto/PostgreSQL.hs

Co-Authored-By: Matt Parsons <parsonsmatt@gmail.com>

* Update src/Database/Esqueleto/PostgreSQL.hs

Co-Authored-By: Matt Parsons <parsonsmatt@gmail.com>

* Move non postgres related functions out of postgres module to internal.

* add entry to changelog
This commit is contained in:
Jose Duran 2019-10-28 12:56:34 -05:00 committed by Matt Parsons
parent 7608d716a1
commit 40f7a0ca97
5 changed files with 199 additions and 15 deletions

View File

@ -1,3 +1,9 @@
Unreleased (3.1.3)
========
- @JoseD92
- [#155](https://github.com/bitemyapp/esqueleto/pull/149): Added `insertSelectWithConflict` postgres function.
3.1.2 3.1.2
======== ========

View File

@ -887,6 +887,54 @@ instance ( ToSomeValues a
toSomeValues c ++ toSomeValues d ++ toSomeValues e ++ toSomeValues f ++ toSomeValues c ++ toSomeValues d ++ toSomeValues e ++ toSomeValues f ++
toSomeValues g ++ toSomeValues h toSomeValues g ++ toSomeValues h
type family KnowResult a where
KnowResult (i -> o) = KnowResult o
KnowResult a = a
-- | A class for constructors or function which result type is known.
--
-- @since 3.1.3
class FinalResult a where
finalR :: a -> KnowResult a
instance FinalResult (Unique val) where
finalR = id
instance (FinalResult b) => FinalResult (a -> b) where
finalR f = finalR (f undefined)
-- | Convert a constructor for a 'Unique' key on a record to the 'UniqueDef' that defines it. You
-- can supply just the constructor itself, or a value of the type - the library is capable of figuring
-- it out from there.
--
-- @since 3.1.3
toUniqueDef :: forall a val. (KnowResult a ~ (Unique val), PersistEntity val,FinalResult a) =>
a -> UniqueDef
toUniqueDef uniqueConstructor = uniqueDef
where
proxy :: Proxy val
proxy = Proxy
unique :: Unique val
unique = finalR uniqueConstructor
-- there must be a better way to get the constrain name from a unique, make this not a list search
filterF = (==) (persistUniqueToFieldNames unique) . uniqueFields
uniqueDef = head . filter filterF . entityUniques . entityDef $ proxy
-- | Render updates to be use in a SET clause for a given sql backend.
--
-- @since 3.1.3
renderUpdates :: (BackendCompatible SqlBackend backend) =>
backend
-> [SqlExpr (Update val)]
-> (TLB.Builder, [PersistValue])
renderUpdates conn = uncommas' . concatMap renderUpdate
where
mk :: SqlExpr (Value ()) -> [(TLB.Builder, [PersistValue])]
mk (ERaw _ f) = [f info]
mk (ECompositeKey _) = throw (CompositeKeyErr MakeSetError) -- FIXME
renderUpdate :: SqlExpr (Update val) -> [(TLB.Builder, [PersistValue])]
renderUpdate (ESet f) = mk (f undefined) -- second parameter of f is always unused
info = (projectBackend conn, initialIdentState)
-- | Data type that represents an @INNER JOIN@ (see 'LeftOuterJoin' for an example). -- | Data type that represents an @INNER JOIN@ (see 'LeftOuterJoin' for an example).
data InnerJoin a b = a `InnerJoin` b data InnerJoin a b = a `InnerJoin` b

View File

@ -1,6 +1,7 @@
{-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE OverloadedStrings {-# LANGUAGE OverloadedStrings
, GADTs, CPP , GADTs, CPP, Rank2Types
, ScopedTypeVariables
#-} #-}
-- | This module contain PostgreSQL-specific functions. -- | This module contain PostgreSQL-specific functions.
-- --
@ -20,6 +21,8 @@ module Database.Esqueleto.PostgreSQL
, random_ , random_
, upsert , upsert
, upsertBy , upsertBy
, insertSelectWithConflict
, insertSelectWithConflictCount
-- * Internal -- * Internal
, unsafeSqlAggregateFunction , unsafeSqlAggregateFunction
) where ) where
@ -33,15 +36,19 @@ import Database.Esqueleto.Internal.Language hiding (random_)
import Database.Esqueleto.Internal.PersistentImport hiding (upsert, upsertBy) import Database.Esqueleto.Internal.PersistentImport hiding (upsert, upsertBy)
import Database.Esqueleto.Internal.Sql import Database.Esqueleto.Internal.Sql
import Database.Esqueleto.Internal.Internal (EsqueletoError(..), CompositeKeyError(..), import Database.Esqueleto.Internal.Internal (EsqueletoError(..), CompositeKeyError(..),
UnexpectedCaseError(..), SetClause) UnexpectedCaseError(..), SetClause, Ident(..),
uncommas, FinalResult(..), toUniqueDef,
KnowResult, renderUpdates)
import Database.Persist.Class (OnlyOneUniqueKey) import Database.Persist.Class (OnlyOneUniqueKey)
import Data.List.NonEmpty ( NonEmpty( (:|) ) ) import Data.List.NonEmpty ( NonEmpty( (:|) ) )
import Data.Int (Int64)
import Data.Proxy (Proxy(..))
import Control.Arrow ((***), first) import Control.Arrow ((***), first)
import Control.Exception (Exception, throw, throwIO) import Control.Exception (Exception, throw, throwIO)
import Control.Monad (void)
import Control.Monad.IO.Class (MonadIO (..)) import Control.Monad.IO.Class (MonadIO (..))
import qualified Control.Monad.Trans.Reader as R import qualified Control.Monad.Trans.Reader as R
-- | (@random()@) Split out into database specific modules -- | (@random()@) Split out into database specific modules
-- because MySQL uses `rand()`. -- because MySQL uses `rand()`.
-- --
@ -199,18 +206,95 @@ upsertBy uniqueKey record updates = do
where where
addVals l = map toPersistValue (toPersistFields record) ++ l ++ persistUniqueToValues uniqueKey addVals l = map toPersistValue (toPersistFields record) ++ l ++ persistUniqueToValues uniqueKey
entDef = entityDef (Just record) entDef = entityDef (Just record)
uDef = head $ filter ((==) (persistUniqueToFieldNames uniqueKey) . uniqueFields) $ entityUniques entDef uDef = toUniqueDef uniqueKey
updatesText conn = first builderToText $ renderUpdates conn updates updatesText conn = first builderToText $ renderUpdates conn updates
handler conn f = fmap head $ uncurry rawSql $ handler conn f = fmap head $ uncurry rawSql $
(***) (f entDef (uDef :| [])) addVals $ updatesText conn (***) (f entDef (uDef :| [])) addVals $ updatesText conn
renderUpdates :: SqlBackend
-> [SqlExpr (Update val)] -- | Inserts into a table the results of a query similar to 'insertSelect' but allows
-> (TLB.Builder, [PersistValue]) -- to update values that violate a constraint during insertions.
renderUpdates conn = uncommas' . concatMap renderUpdate --
where -- Example of usage:
mk :: SqlExpr (Value ()) -> [(TLB.Builder, [PersistValue])] --
mk (ERaw _ f) = [f info] -- @
mk (ECompositeKey _) = throw (CompositeKeyErr MakeSetError) -- FIXME -- share [ mkPersist sqlSettings
renderUpdate :: SqlExpr (Update val) -> [(TLB.Builder, [PersistValue])] -- , mkDeleteCascade sqlSettings
renderUpdate (ESet f) = mk (f undefined) -- second parameter of f is always unused -- , mkMigrate "migrate"
info = (projectBackend conn, initialIdentState) -- ] [persistLowerCase|
-- Bar
-- num Int
-- deriving Eq Show
-- Foo
-- num Int
-- UniqueFoo num
-- deriving Eq Show
-- |]
--
-- insertSelectWithConflict
-- UniqueFoo -- (UniqueFoo undefined) or (UniqueFoo anyNumber) would also work
-- (from $ \b ->
-- return $ Foo <# (b ^. BarNum)
-- )
-- (\current excluded ->
-- [FooNum =. (current ^. FooNum) +. (excluded ^. FooNum)]
-- )
-- @
--
-- Inserts to table Foo all Bar.num values and in case of conflict SomeFooUnique,
-- the conflicting value is updated to the current plus the excluded.
--
-- @since 3.1.3
insertSelectWithConflict :: forall a m val. (
FinalResult a,
KnowResult a ~ (Unique val),
MonadIO m,
PersistEntity val) =>
a
-- ^ Unique constructor or a unique, this is used just to get the name of the postgres constraint, the value(s) is(are) never used, so if you have a unique "MyUnique 0", "MyUnique undefined" would work as well.
-> SqlQuery (SqlExpr (Insertion val))
-- ^ Insert query.
-> (SqlExpr (Entity val) -> SqlExpr (Entity val) -> [SqlExpr (Update val)])
-- ^ A list of updates to be applied in case of the constraint being violated. The expression takes the current and excluded value to produce the updates.
-> SqlWriteT m ()
insertSelectWithConflict unique query = void . insertSelectWithConflictCount unique query
-- | Same as 'insertSelectWithConflict' but returns the number of rows affected.
--
-- @since 3.1.3
insertSelectWithConflictCount :: forall a val m. (
FinalResult a,
KnowResult a ~ (Unique val),
MonadIO m,
PersistEntity val) =>
a
-> SqlQuery (SqlExpr (Insertion val))
-> (SqlExpr (Entity val) -> SqlExpr (Entity val) -> [SqlExpr (Update val)])
-> SqlWriteT m Int64
insertSelectWithConflictCount unique query conflictQuery = do
conn <- R.ask
uncurry rawExecuteCount $
combine
(toRawSql INSERT_INTO (conn, initialIdentState) (fmap EInsertFinal query))
(conflict conn)
where
proxy :: Proxy val
proxy = Proxy
updates = conflictQuery entCurrent entExcluded
combine (tlb1,vals1) (tlb2,vals2) = (builderToText (tlb1 `mappend` tlb2), vals1 ++ vals2)
entExcluded = EEntity $ I "excluded"
tableName = unDBName . entityDB . entityDef
entCurrent = EEntity $ I (tableName proxy)
uniqueDef = toUniqueDef unique
constraint = TLB.fromText . unDBName . uniqueDBName $ uniqueDef
renderedUpdates :: (BackendCompatible SqlBackend backend) => backend -> (TLB.Builder, [PersistValue])
renderedUpdates conn = renderUpdates conn updates
conflict conn = (foldr1 mappend ([
TLB.fromText "ON CONFLICT ON CONSTRAINT \"",
constraint,
TLB.fromText "\" DO "
] ++ if null updates then [TLB.fromText "NOTHING"] else [
TLB.fromText "UPDATE SET ",
updatesTLB
]),values)
where
(updatesTLB,values) = renderedUpdates conn

View File

@ -52,6 +52,7 @@ module Common.Test
, Circle (..) , Circle (..)
, Numbers (..) , Numbers (..)
, OneUnique(..) , OneUnique(..)
, Unique(..)
) where ) where
import Control.Monad (forM_, replicateM, replicateM_, void) import Control.Monad (forM_, replicateM, replicateM_, void)

View File

@ -978,6 +978,50 @@ testUpsert =
u3e <- EP.upsert u3 [OneUniqueName =. val "fifth"] u3e <- EP.upsert u3 [OneUniqueName =. val "fifth"]
liftIO $ entityVal u3e `shouldBe` u1{oneUniqueName="fifth"} liftIO $ entityVal u3e `shouldBe` u1{oneUniqueName="fifth"}
testInsertSelectWithConflict :: Spec
testInsertSelectWithConflict =
describe "insertSelectWithConflict test" $ do
it "Should do Nothing when no updates set" $ run $ do
_ <- insert p1
_ <- insert p2
_ <- insert p3
n1 <- EP.insertSelectWithConflictCount UniqueValue (
from $ \p -> return $ OneUnique <# val "test" <&> (p ^. PersonFavNum)
)
(\current excluded -> [])
uniques1 <- select $ from $ \u -> return u
n2 <- EP.insertSelectWithConflictCount UniqueValue (
from $ \p -> return $ OneUnique <# val "test" <&> (p ^. PersonFavNum)
)
(\current excluded -> [])
uniques2 <- select $ from $ \u -> return u
liftIO $ n1 `shouldBe` 3
liftIO $ n2 `shouldBe` 0
let test = map (OneUnique "test" . personFavNum) [p1,p2,p3]
liftIO $ map entityVal uniques1 `shouldBe` test
liftIO $ map entityVal uniques2 `shouldBe` test
it "Should update a value if given an update on conflict" $ run $ do
_ <- insert p1
_ <- insert p2
_ <- insert p3
-- Note, have to sum 4 so that the update does not conflicts again with another row.
n1 <- EP.insertSelectWithConflictCount UniqueValue (
from $ \p -> return $ OneUnique <# val "test" <&> (p ^. PersonFavNum)
)
(\current excluded -> [OneUniqueValue =. val 4 +. (current ^. OneUniqueValue) +. (excluded ^. OneUniqueValue)])
uniques1 <- select $ from $ \u -> return u
n2 <- EP.insertSelectWithConflictCount UniqueValue (
from $ \p -> return $ OneUnique <# val "test" <&> (p ^. PersonFavNum)
)
(\current excluded -> [OneUniqueValue =. val 4 +. (current ^. OneUniqueValue) +. (excluded ^. OneUniqueValue)])
uniques2 <- select $ from $ \u -> return u
liftIO $ n1 `shouldBe` 3
liftIO $ n2 `shouldBe` 3
let test = map (OneUnique "test" . personFavNum) [p1,p2,p3]
test2 = map (OneUnique "test" . (+4) . (*2) . personFavNum) [p1,p2,p3]
liftIO $ map entityVal uniques1 `shouldBe` test
liftIO $ map entityVal uniques2 `shouldBe` test2
type JSONValue = Maybe (JSONB A.Value) type JSONValue = Maybe (JSONB A.Value)
createSaneSQL :: (PersistField a) => SqlExpr (Value a) -> T.Text -> [PersistValue] -> IO () createSaneSQL :: (PersistField a) => SqlExpr (Value a) -> T.Text -> [PersistValue] -> IO ()
@ -1051,6 +1095,7 @@ main = do
testPostgresqlTextFunctions testPostgresqlTextFunctions
testInsertUniqueViolation testInsertUniqueViolation
testUpsert testUpsert
testInsertSelectWithConflict
describe "PostgreSQL JSON tests" $ do describe "PostgreSQL JSON tests" $ do
-- NOTE: We only clean the table once, so we -- NOTE: We only clean the table once, so we
-- can use its contents across all JSON tests -- can use its contents across all JSON tests