diff --git a/src/Database/Esqueleto.hs b/src/Database/Esqueleto.hs index 773fec1..d082191 100644 --- a/src/Database/Esqueleto.hs +++ b/src/Database/Esqueleto.hs @@ -100,7 +100,6 @@ module Database.Esqueleto -- * Helpers , valkey , valJ - , EsqueletoUpsert(..) -- * Re-exports -- $reexports diff --git a/src/Database/Esqueleto/Internal/Internal.hs b/src/Database/Esqueleto/Internal/Internal.hs index 8db442f..88ce078 100644 --- a/src/Database/Esqueleto/Internal/Internal.hs +++ b/src/Database/Esqueleto/Internal/Internal.hs @@ -56,9 +56,6 @@ import qualified Data.Text.Lazy.Builder as TLB import Data.Typeable (Typeable) import Text.Blaze.Html (Html) -import Database.Persist.Class (OnlyOneUniqueKey) -import Control.Monad.Reader (ReaderT) -import Data.List.NonEmpty( NonEmpty( (:|) ) ) -- | (Internal) Start a 'from' query with an entity. 'from' -- does two kinds of magic using 'fromStart', 'fromJoin' and @@ -1266,6 +1263,7 @@ data UnexpectedCaseError = | InsertionFinalError | NewIdentForError | UnsafeSqlCaseError + | OperationNotSupported deriving (Show) data SqlBinOpCompositeError = @@ -2885,88 +2883,4 @@ insertSelect = void . insertSelectCount -- | Insert a 'PersistField' for every selected value, return the count afterward insertSelectCount :: (MonadIO m, PersistEntity a) => SqlQuery (SqlExpr (Insertion a)) -> SqlWriteT m Int64 -insertSelectCount = rawEsqueleto INSERT_INTO . fmap EInsertFinal - --- | A class for allowing the use of upsert operation using --- esqueleto's types. -class (PersistUniqueWrite backend, - PersistQueryWrite backend, - IsPersistBackend (BaseBackend backend), - BackendCompatible SqlBackend backend, - BackendCompatible SqlBackend (BaseBackend backend)) => - EsqueletoUpsert backend where - upsert - :: (MonadIO m, PersistRecordBackend record backend, OnlyOneUniqueKey record) - => record - -- ^ new record to insert - -> [SqlExpr (Update record)] - -- ^ updates to perform if the record already exists - -> ReaderT backend m (Entity record) - -- ^ the record in the database after the operation - upsert record updates = do - uniqueKey <- onlyUnique record - upsertBy uniqueKey record updates - - upsertBy :: (MonadIO m, PersistRecordBackend record backend) - => Unique record - -- ^ uniqueness constraint to find by - -> record - -- ^ new record to insert - -> [SqlExpr (Update record)] - -- ^ updates to perform if the record already exists - -> ReaderT backend m (Entity record) - -- ^ the record in the database after the operation - upsertBy = defaultUpsert - -defaultUpsert - :: (MonadIO m, PersistRecordBackend record backend, - PersistQueryWrite backend, - PersistUniqueWrite backend, - IsPersistBackend (BaseBackend backend), - BackendCompatible SqlBackend backend, - BackendCompatible SqlBackend (BaseBackend backend)) - => Unique record - -> record - -> [SqlExpr (Update record)] - -> ReaderT backend m (Entity record) -defaultUpsert uniqueKey record updates = do - mrecord <- getBy uniqueKey - maybe (insertEntity record) updateGetEntity mrecord - where - updateGetEntity (Entity k _) = fmap head $ do - update $ \r -> do - set r updates - where_ (r ^. persistIdField ==. val k) - select $ from $ \r -> do - where_ (r ^. persistIdField ==. val k) - return r - --- Currently only postgres implements connUpsertSql, check that '?' are --- added in the same order as postgres when adding connUpsertSql to another --- backend. -instance EsqueletoUpsert SqlBackend where - upsertBy uniqueKey record updates = do - sqlB <- R.ask - maybe - (defaultUpsert uniqueKey record updates) - (handler sqlB) - (connUpsertSql sqlB) - where - addVals l = map toPersistValue (toPersistFields record) ++ l ++ persistUniqueToValues uniqueKey - entDef = entityDef (Just record) - uDef = head $ filter ((==) (persistUniqueToFieldNames uniqueKey) . uniqueFields) $ entityUniques entDef - updatesText conn = first builderToText $ renderUpdates conn updates - handler conn f = fmap head $ uncurry rawSql $ - (***) (f entDef (uDef :| [])) addVals $ updatesText conn - --- | Renders a [SqlExpr (Update val)] into a (TLB.Builder, [PersistValue]) with a given backend. -renderUpdates :: BackendCompatible SqlBackend backend => - backend - -> [SqlExpr (Update val)] - -> (TLB.Builder, [PersistValue]) -renderUpdates conn = uncommas' . concatMap renderUpdate - where - mk (ERaw _ f) = [f info] - mk (ECompositeKey _) = throw (CompositeKeyErr MakeSetError) -- FIXME - renderUpdate (ESet f) = mk (f undefined) -- second parameter of f is always unused - info = (projectBackend conn, initialIdentState) \ No newline at end of file +insertSelectCount = rawEsqueleto INSERT_INTO . fmap EInsertFinal \ No newline at end of file diff --git a/src/Database/Esqueleto/Internal/PersistentImport.hs b/src/Database/Esqueleto/Internal/PersistentImport.hs index 43725b0..638d538 100644 --- a/src/Database/Esqueleto/Internal/PersistentImport.hs +++ b/src/Database/Esqueleto/Internal/PersistentImport.hs @@ -147,4 +147,4 @@ import Database.Persist.Sql hiding , selectKeysList, deleteCascadeWhere, (=.), (+=.), (-=.), (*=.), (/=.) , (==.), (!=.), (<.), (>.), (<=.), (>=.), (<-.), (/<-.), (||.) , listToJSON, mapToJSON, getPersistMap, limitOffsetOrder, selectSource - , update , count , upsertBy, upsert) + , update , count) diff --git a/src/Database/Esqueleto/Internal/Sql.hs b/src/Database/Esqueleto/Internal/Sql.hs index e50e6e4..7818624 100644 --- a/src/Database/Esqueleto/Internal/Sql.hs +++ b/src/Database/Esqueleto/Internal/Sql.hs @@ -71,7 +71,6 @@ module Database.Esqueleto.Internal.Sql , parens , toArgList , builderToText - , EsqueletoUpsert(..) ) where import Database.Esqueleto.Internal.Internal diff --git a/src/Database/Esqueleto/PostgreSQL.hs b/src/Database/Esqueleto/PostgreSQL.hs index b89cac7..f9254dd 100644 --- a/src/Database/Esqueleto/PostgreSQL.hs +++ b/src/Database/Esqueleto/PostgreSQL.hs @@ -18,6 +18,8 @@ module Database.Esqueleto.PostgreSQL , chr , now_ , random_ + , upsert + , upsertBy -- * Internal , unsafeSqlAggregateFunction ) where @@ -28,8 +30,17 @@ import Data.Semigroup import qualified Data.Text.Internal.Builder as TLB import Data.Time.Clock (UTCTime) import Database.Esqueleto.Internal.Language hiding (random_) -import Database.Esqueleto.Internal.PersistentImport +import Database.Esqueleto.Internal.PersistentImport hiding (upsert, upsertBy) import Database.Esqueleto.Internal.Sql +import Database.Esqueleto.Internal.Internal (EsqueletoError(..), CompositeKeyError(..), + UnexpectedCaseError(..)) +import Database.Persist.Class (OnlyOneUniqueKey) +import Data.List.NonEmpty ( NonEmpty( (:|) ) ) +import Control.Arrow ((***), first) +import Control.Exception (Exception, throw, throwIO) +import Control.Monad.IO.Class (MonadIO (..)) +import qualified Control.Monad.Trans.Reader as R + -- | (@random()@) Split out into database specific modules -- because MySQL uses `rand()`. @@ -152,3 +163,52 @@ chr = unsafeSqlFunction "chr" now_ :: SqlExpr (Value UTCTime) now_ = unsafeSqlValue "NOW()" + +upsert :: (MonadIO m, + PersistEntity record, + OnlyOneUniqueKey record, + PersistRecordBackend record SqlBackend, + IsPersistBackend (PersistEntityBackend record)) + => record + -- ^ new record to insert + -> [SqlExpr (Update record)] + -- ^ updates to perform if the record already exists + -> R.ReaderT SqlBackend m (Entity record) + -- ^ the record in the database after the operation +upsert record updates = do + uniqueKey <- onlyUnique record + upsertBy uniqueKey record updates + +upsertBy :: (MonadIO m, + PersistEntity record, + IsPersistBackend (PersistEntityBackend record)) + => Unique record + -- ^ uniqueness constraint to find by + -> record + -- ^ new record to insert + -> [SqlExpr (Update record)] + -- ^ updates to perform if the record already exists + -> R.ReaderT SqlBackend m (Entity record) + -- ^ the record in the database after the operation +upsertBy uniqueKey record updates = do + sqlB <- R.ask + maybe + (throw (UnexpectedCaseErr OperationNotSupported)) -- Postgres backend should have connUpsertSql, if this error is thrown, check changes on persistent + (handler sqlB) + (connUpsertSql sqlB) + where + addVals l = map toPersistValue (toPersistFields record) ++ l ++ persistUniqueToValues uniqueKey + entDef = entityDef (Just record) + uDef = head $ filter ((==) (persistUniqueToFieldNames uniqueKey) . uniqueFields) $ entityUniques entDef + updatesText conn = first builderToText $ renderUpdates conn updates + handler conn f = fmap head $ uncurry rawSql $ + (***) (f entDef (uDef :| [])) addVals $ updatesText conn + renderUpdates :: SqlBackend + -> [SqlExpr (Update val)] + -> (TLB.Builder, [PersistValue]) + renderUpdates conn = uncommas' . concatMap renderUpdate + where + mk (ERaw _ f) = [f info] + mk (ECompositeKey _) = throw (CompositeKeyErr MakeSetError) -- FIXME + renderUpdate (ESet f) = mk (f undefined) -- second parameter of f is always unused + info = (projectBackend conn, initialIdentState) \ No newline at end of file