diff --git a/changelog.md b/changelog.md index 1eec4fa..d0b74ad 100644 --- a/changelog.md +++ b/changelog.md @@ -1,6 +1,9 @@ Unreleased (3.1.1) ======== +- @JoseD92 + - [#149](https://github.com/bitemyapp/esqueleto/pull/149): Added `upsert` support. + - @parsonsmatt - [#133](https://github.com/bitemyapp/esqueleto/pull/133): Added `renderQueryToText` and related functions. diff --git a/src/Database/Esqueleto/Internal/Internal.hs b/src/Database/Esqueleto/Internal/Internal.hs index c464807..d6d165b 100644 --- a/src/Database/Esqueleto/Internal/Internal.hs +++ b/src/Database/Esqueleto/Internal/Internal.hs @@ -1263,6 +1263,7 @@ data UnexpectedCaseError = | InsertionFinalError | NewIdentForError | UnsafeSqlCaseError + | OperationNotSupported deriving (Show) data SqlBinOpCompositeError = diff --git a/src/Database/Esqueleto/PostgreSQL.hs b/src/Database/Esqueleto/PostgreSQL.hs index b89cac7..cbbf788 100644 --- a/src/Database/Esqueleto/PostgreSQL.hs +++ b/src/Database/Esqueleto/PostgreSQL.hs @@ -18,6 +18,8 @@ module Database.Esqueleto.PostgreSQL , chr , now_ , random_ + , upsert + , upsertBy -- * Internal , unsafeSqlAggregateFunction ) where @@ -28,8 +30,17 @@ import Data.Semigroup import qualified Data.Text.Internal.Builder as TLB import Data.Time.Clock (UTCTime) import Database.Esqueleto.Internal.Language hiding (random_) -import Database.Esqueleto.Internal.PersistentImport +import Database.Esqueleto.Internal.PersistentImport hiding (upsert, upsertBy) import Database.Esqueleto.Internal.Sql +import Database.Esqueleto.Internal.Internal (EsqueletoError(..), CompositeKeyError(..), + UnexpectedCaseError(..), SetClause) +import Database.Persist.Class (OnlyOneUniqueKey) +import Data.List.NonEmpty ( NonEmpty( (:|) ) ) +import Control.Arrow ((***), first) +import Control.Exception (Exception, throw, throwIO) +import Control.Monad.IO.Class (MonadIO (..)) +import qualified Control.Monad.Trans.Reader as R + -- | (@random()@) Split out into database specific modules -- because MySQL uses `rand()`. @@ -152,3 +163,54 @@ chr = unsafeSqlFunction "chr" now_ :: SqlExpr (Value UTCTime) now_ = unsafeSqlValue "NOW()" + +upsert :: (MonadIO m, + PersistEntity record, + OnlyOneUniqueKey record, + PersistRecordBackend record SqlBackend, + IsPersistBackend (PersistEntityBackend record)) + => record + -- ^ new record to insert + -> [SqlExpr (Update record)] + -- ^ updates to perform if the record already exists + -> R.ReaderT SqlBackend m (Entity record) + -- ^ the record in the database after the operation +upsert record updates = do + uniqueKey <- onlyUnique record + upsertBy uniqueKey record updates + +upsertBy :: (MonadIO m, + PersistEntity record, + IsPersistBackend (PersistEntityBackend record)) + => Unique record + -- ^ uniqueness constraint to find by + -> record + -- ^ new record to insert + -> [SqlExpr (Update record)] + -- ^ updates to perform if the record already exists + -> R.ReaderT SqlBackend m (Entity record) + -- ^ the record in the database after the operation +upsertBy uniqueKey record updates = do + sqlB <- R.ask + maybe + (throw (UnexpectedCaseErr OperationNotSupported)) -- Postgres backend should have connUpsertSql, if this error is thrown, check changes on persistent + (handler sqlB) + (connUpsertSql sqlB) + where + addVals l = map toPersistValue (toPersistFields record) ++ l ++ persistUniqueToValues uniqueKey + entDef = entityDef (Just record) + uDef = head $ filter ((==) (persistUniqueToFieldNames uniqueKey) . uniqueFields) $ entityUniques entDef + updatesText conn = first builderToText $ renderUpdates conn updates + handler conn f = fmap head $ uncurry rawSql $ + (***) (f entDef (uDef :| [])) addVals $ updatesText conn + renderUpdates :: SqlBackend + -> [SqlExpr (Update val)] + -> (TLB.Builder, [PersistValue]) + renderUpdates conn = uncommas' . concatMap renderUpdate + where + mk :: SqlExpr (Value ()) -> [(TLB.Builder, [PersistValue])] + mk (ERaw _ f) = [f info] + mk (ECompositeKey _) = throw (CompositeKeyErr MakeSetError) -- FIXME + renderUpdate :: SqlExpr (Update val) -> [(TLB.Builder, [PersistValue])] + renderUpdate (ESet f) = mk (f undefined) -- second parameter of f is always unused + info = (projectBackend conn, initialIdentState) \ No newline at end of file diff --git a/test/Common/Test.hs b/test/Common/Test.hs index a5450fb..9c6d367 100644 --- a/test/Common/Test.hs +++ b/test/Common/Test.hs @@ -25,11 +25,14 @@ module Common.Test , testAscRandom , testRandomMath , migrateAll + , migrateUnique , cleanDB + , cleanUniques , RunDbMonad , Run , p1, p2, p3, p4, p5 , l1, l2, l3 + , u1, u2, u3, u4 , insert' , EntityField (..) , Foo (..) @@ -48,6 +51,7 @@ module Common.Test , Point (..) , Circle (..) , Numbers (..) + , OneUnique(..) ) where import Control.Monad (forM_, replicateM, replicateM_, void) @@ -157,8 +161,14 @@ share [mkPersist sqlSettings, mkMigrate "migrateAll"] [persistUpperCase| double Double |] - - +-- Unique Test schema +share [mkPersist sqlSettings, mkMigrate "migrateUnique"] [persistUpperCase| + OneUnique + name String + value Int + UniqueValue value + deriving Eq Show +|] -- | this could be achieved with S.fromList, but not all lists -- have Ord instances @@ -196,7 +206,17 @@ l2 = Lord "Dorset" Nothing l3 :: Lord l3 = Lord "Chester" (Just 17) +u1 :: OneUnique +u1 = OneUnique "First" 0 +u2 :: OneUnique +u2 = OneUnique "Second" 1 + +u3 :: OneUnique +u3 = OneUnique "Third" 0 + +u4 :: OneUnique +u4 = OneUnique "First" 2 testSelect :: Run -> Spec testSelect run = do @@ -1536,3 +1556,10 @@ cleanDB = do delete $ from $ \(_ :: SqlExpr (Entity Point)) -> return () delete $ from $ \(_ :: SqlExpr (Entity Numbers)) -> return () + + +cleanUniques + :: (forall m. RunDbMonad m + => SqlPersistT (R.ResourceT m) ()) +cleanUniques = + delete $ from $ \(_ :: SqlExpr (Entity OneUnique)) -> return () \ No newline at end of file diff --git a/test/PostgreSQL/Test.hs b/test/PostgreSQL/Test.hs index 8f2c4a2..407cb41 100644 --- a/test/PostgreSQL/Test.hs +++ b/test/PostgreSQL/Test.hs @@ -33,7 +33,7 @@ import qualified Database.Esqueleto.PostgreSQL as EP import Database.Esqueleto.PostgreSQL.JSON hiding ((?.), (-.), (||.)) import qualified Database.Esqueleto.PostgreSQL.JSON as JSON import Database.Persist.Postgresql (withPostgresqlConn) -import Database.PostgreSQL.Simple (SqlError(..)) +import Database.PostgreSQL.Simple (SqlError(..), ExecStatus(..)) import System.Environment import Test.Hspec @@ -949,6 +949,34 @@ testHashMinusOperator = where_ $ v @>. jsonbVal (object []) where_ $ f v +testInsertUniqueViolation :: Spec +testInsertUniqueViolation = + describe "Unique Violation on Insert" $ + it "Unique throws exception" $ run (do + _ <- insert u1 + _ <- insert u2 + insert u3) `shouldThrow` (==) exception + where + exception = SqlError { + sqlState = "23505", + sqlExecStatus = FatalError, + sqlErrorMsg = "duplicate key value violates unique constraint \"UniqueValue\"", + sqlErrorDetail = "Key (value)=(0) already exists.", + sqlErrorHint = ""} + +testUpsert :: Spec +testUpsert = + describe "Upsert test" $ do + it "Upsert can insert like normal" $ run $ do + u1e <- EP.upsert u1 [OneUniqueName =. val "fifth"] + liftIO $ entityVal u1e `shouldBe` u1 + it "Upsert performs update on collision" $ run $ do + u1e <- EP.upsert u1 [OneUniqueName =. val "fifth"] + liftIO $ entityVal u1e `shouldBe` u1 + u2e <- EP.upsert u2 [OneUniqueName =. val "fifth"] + liftIO $ entityVal u2e `shouldBe` u2 + u3e <- EP.upsert u3 [OneUniqueName =. val "fifth"] + liftIO $ entityVal u3e `shouldBe` u1{oneUniqueName="fifth"} type JSONValue = Maybe (JSONB A.Value) @@ -1021,6 +1049,8 @@ main = do testPostgresqlUpdate testPostgresqlCoalesce testPostgresqlTextFunctions + testInsertUniqueViolation + testUpsert describe "PostgreSQL JSON tests" $ do -- NOTE: We only clean the table once, so we -- can use its contents across all JSON tests @@ -1053,7 +1083,9 @@ run_worker act = withConn $ runSqlConn (migrateIt >> act) migrateIt :: RunDbMonad m => SqlPersistT (R.ResourceT m) () migrateIt = do void $ runMigrationSilent migrateAll + void $ runMigrationSilent migrateUnique cleanDB + cleanUniques withConn :: RunDbMonad m => (SqlBackend -> R.ResourceT m a) -> m a withConn =