diff --git a/esqueleto.cabal b/esqueleto.cabal index b89ad99..965b082 100644 --- a/esqueleto.cabal +++ b/esqueleto.cabal @@ -49,6 +49,7 @@ library , bytestring , conduit >=1.3 , monad-logger + , mtl , persistent >=2.10.0 && <2.11 , resourcet >=1.2 , tagged >=0.2 diff --git a/src/Database/Esqueleto.hs b/src/Database/Esqueleto.hs index d082191..773fec1 100644 --- a/src/Database/Esqueleto.hs +++ b/src/Database/Esqueleto.hs @@ -100,6 +100,7 @@ module Database.Esqueleto -- * Helpers , valkey , valJ + , EsqueletoUpsert(..) -- * Re-exports -- $reexports diff --git a/src/Database/Esqueleto/Internal/Internal.hs b/src/Database/Esqueleto/Internal/Internal.hs index c464807..8db442f 100644 --- a/src/Database/Esqueleto/Internal/Internal.hs +++ b/src/Database/Esqueleto/Internal/Internal.hs @@ -56,6 +56,9 @@ import qualified Data.Text.Lazy.Builder as TLB import Data.Typeable (Typeable) import Text.Blaze.Html (Html) +import Database.Persist.Class (OnlyOneUniqueKey) +import Control.Monad.Reader (ReaderT) +import Data.List.NonEmpty( NonEmpty( (:|) ) ) -- | (Internal) Start a 'from' query with an entity. 'from' -- does two kinds of magic using 'fromStart', 'fromJoin' and @@ -2883,3 +2886,87 @@ insertSelect = void . insertSelectCount insertSelectCount :: (MonadIO m, PersistEntity a) => SqlQuery (SqlExpr (Insertion a)) -> SqlWriteT m Int64 insertSelectCount = rawEsqueleto INSERT_INTO . fmap EInsertFinal + +-- | A class for allowing the use of upsert operation using +-- esqueleto's types. +class (PersistUniqueWrite backend, + PersistQueryWrite backend, + IsPersistBackend (BaseBackend backend), + BackendCompatible SqlBackend backend, + BackendCompatible SqlBackend (BaseBackend backend)) => + EsqueletoUpsert backend where + upsert + :: (MonadIO m, PersistRecordBackend record backend, OnlyOneUniqueKey record) + => record + -- ^ new record to insert + -> [SqlExpr (Update record)] + -- ^ updates to perform if the record already exists + -> ReaderT backend m (Entity record) + -- ^ the record in the database after the operation + upsert record updates = do + uniqueKey <- onlyUnique record + upsertBy uniqueKey record updates + + upsertBy :: (MonadIO m, PersistRecordBackend record backend) + => Unique record + -- ^ uniqueness constraint to find by + -> record + -- ^ new record to insert + -> [SqlExpr (Update record)] + -- ^ updates to perform if the record already exists + -> ReaderT backend m (Entity record) + -- ^ the record in the database after the operation + upsertBy = defaultUpsert + +defaultUpsert + :: (MonadIO m, PersistRecordBackend record backend, + PersistQueryWrite backend, + PersistUniqueWrite backend, + IsPersistBackend (BaseBackend backend), + BackendCompatible SqlBackend backend, + BackendCompatible SqlBackend (BaseBackend backend)) + => Unique record + -> record + -> [SqlExpr (Update record)] + -> ReaderT backend m (Entity record) +defaultUpsert uniqueKey record updates = do + mrecord <- getBy uniqueKey + maybe (insertEntity record) updateGetEntity mrecord + where + updateGetEntity (Entity k _) = fmap head $ do + update $ \r -> do + set r updates + where_ (r ^. persistIdField ==. val k) + select $ from $ \r -> do + where_ (r ^. persistIdField ==. val k) + return r + +-- Currently only postgres implements connUpsertSql, check that '?' are +-- added in the same order as postgres when adding connUpsertSql to another +-- backend. +instance EsqueletoUpsert SqlBackend where + upsertBy uniqueKey record updates = do + sqlB <- R.ask + maybe + (defaultUpsert uniqueKey record updates) + (handler sqlB) + (connUpsertSql sqlB) + where + addVals l = map toPersistValue (toPersistFields record) ++ l ++ persistUniqueToValues uniqueKey + entDef = entityDef (Just record) + uDef = head $ filter ((==) (persistUniqueToFieldNames uniqueKey) . uniqueFields) $ entityUniques entDef + updatesText conn = first builderToText $ renderUpdates conn updates + handler conn f = fmap head $ uncurry rawSql $ + (***) (f entDef (uDef :| [])) addVals $ updatesText conn + +-- | Renders a [SqlExpr (Update val)] into a (TLB.Builder, [PersistValue]) with a given backend. +renderUpdates :: BackendCompatible SqlBackend backend => + backend + -> [SqlExpr (Update val)] + -> (TLB.Builder, [PersistValue]) +renderUpdates conn = uncommas' . concatMap renderUpdate + where + mk (ERaw _ f) = [f info] + mk (ECompositeKey _) = throw (CompositeKeyErr MakeSetError) -- FIXME + renderUpdate (ESet f) = mk (f undefined) -- second parameter of f is always unused + info = (projectBackend conn, initialIdentState) \ No newline at end of file diff --git a/src/Database/Esqueleto/Internal/PersistentImport.hs b/src/Database/Esqueleto/Internal/PersistentImport.hs index 55e460b..43725b0 100644 --- a/src/Database/Esqueleto/Internal/PersistentImport.hs +++ b/src/Database/Esqueleto/Internal/PersistentImport.hs @@ -147,4 +147,4 @@ import Database.Persist.Sql hiding , selectKeysList, deleteCascadeWhere, (=.), (+=.), (-=.), (*=.), (/=.) , (==.), (!=.), (<.), (>.), (<=.), (>=.), (<-.), (/<-.), (||.) , listToJSON, mapToJSON, getPersistMap, limitOffsetOrder, selectSource - , update , count ) + , update , count , upsertBy, upsert) diff --git a/src/Database/Esqueleto/Internal/Sql.hs b/src/Database/Esqueleto/Internal/Sql.hs index 7818624..e50e6e4 100644 --- a/src/Database/Esqueleto/Internal/Sql.hs +++ b/src/Database/Esqueleto/Internal/Sql.hs @@ -71,6 +71,7 @@ module Database.Esqueleto.Internal.Sql , parens , toArgList , builderToText + , EsqueletoUpsert(..) ) where import Database.Esqueleto.Internal.Internal