This repository has been archived on 2024-10-24. You can view files and clone it, but cannot push or open issues or pull requests.
fradrive-old/src/Utils/Sql.hs

138 lines
6.9 KiB
Haskell

-- SPDX-FileCopyrightText: 2022 Gregor Kleen <gregor.kleen@ifi.lmu.de>
--
-- SPDX-License-Identifier: AGPL-3.0-or-later
module Utils.Sql
( setSerializable
, setSerializableBatch, setSerializableReadOnlyBatch, setSerializableDeferrableBatch
, SerializableMode(..), setSerializable'
, catchSql, handleSql
, isUniqueConstraintViolation
, catchIfSql, handleIfSql
, trySql
) where
import ClassyPrelude.Yesod hiding (handle)
import Numeric.Natural
import Settings.Log
import Database.PostgreSQL.Simple (SqlError(..))
import Database.PostgreSQL.Simple.Errors (isSerializationError)
import Control.Monad.Catch
import Database.Persist.Sql hiding (IsolationLevel(..))
import qualified Database.Persist.Sql as Persist (IsolationLevel(..))
import Database.Persist.Sql.Types.Instances ()
import Database.Persist.Sql.Raw.QQ
import qualified Data.ByteString as ByteString
import Control.Retry
import Control.Lens ((&))
import qualified Data.UUID as UUID
import Control.Monad.Random.Class (MonadRandom(getRandom))
import Text.Shakespeare.Text (st)
import Control.Concurrent.Async (ExceptionInLinkedThread(..))
import Data.Universe
import Control.Monad.Trans.Reader (withReaderT)
-- | Like 'fromException', but also looks inside 'ExceptionInLinkedThread'.
--
-- The exception is first converted directly; if that fails and it is an
-- 'ExceptionInLinkedThread', the exception carried inside is tried instead
-- (recursively, so nested wrappers are unwrapped as well).
fromExceptionWrapped :: Exception exc => SomeException -> Maybe exc
fromExceptionWrapped someExc = case fromException someExc of
  Just exc -> Just exc
  Nothing -> case fromException someExc of
    -- Not of the requested type itself; retry with the wrapped exception.
    Just (ExceptionInLinkedThread _ inner) -> fromExceptionWrapped inner
    Nothing -> Nothing
-- | Flavours of serializable transaction that @setSerializable'@ can request.
data SerializableMode
  = Serializable
    -- ^ @ISOLATION LEVEL SERIALIZABLE@
  | SerializableReadOnly
    -- ^ @ISOLATION LEVEL SERIALIZABLE, READ ONLY@
  | SerializableReadOnlyDeferrable
    -- ^ @ISOLATION LEVEL SERIALIZABLE, READ ONLY, DEFERRABLE@
  deriving stock (Eq, Ord, Read, Show, Enum, Bounded, Generic)
  deriving anyclass (Universe, Finite)
-- | Run a database action as a serializable transaction, retrying on
-- serialization failures.
--
-- Retries use full-jitter backoff with a base delay of @1e3@ and stop once the
-- cumulative delay would exceed @10e6@ (the retry library measures delays in
-- microseconds, i.e. 1 ms base and 10 s in total).
setSerializable :: forall m a. (MonadLogger m, MonadMask m, MonadIO m, ReadLogSettings (SqlPersistT m)) => SqlPersistT m a -> SqlPersistT m a
setSerializable = setSerializable' Serializable retryPolicy
  where
    retryPolicy = limitRetriesByCumulativeDelay 10e6 (fullJitterBackoff 1e3)
-- | Variant of @setSerializable@ with a much larger retry budget.
--
-- Same isolation mode and base delay (@1e3@), but retries continue until the
-- cumulative delay would exceed @3600e6@ (the retry library measures delays in
-- microseconds, i.e. one hour in total).
setSerializableBatch :: forall m a. (MonadLogger m, MonadMask m, MonadIO m, ReadLogSettings (SqlPersistT m)) => SqlPersistT m a -> SqlPersistT m a
setSerializableBatch = setSerializable' Serializable retryPolicy
  where
    retryPolicy = limitRetriesByCumulativeDelay 3600e6 (fullJitterBackoff 1e3)
-- | Run an action on a 'SqlReadBackend' as a @SERIALIZABLE, READ ONLY@
-- transaction, retrying on serialization failures.
--
-- Full-jitter backoff with base delay @1e3@; retries stop once the cumulative
-- delay would exceed @3600e6@ (microseconds in the retry library, i.e. one
-- hour in total).
setSerializableReadOnlyBatch :: forall m a. (MonadLogger m, MonadMask m, MonadIO m, ReadLogSettings (ReaderT SqlReadBackend m)) => ReaderT SqlReadBackend m a -> ReaderT SqlReadBackend m a
setSerializableReadOnlyBatch = setSerializable' SerializableReadOnly retryPolicy
  where
    retryPolicy = limitRetriesByCumulativeDelay 3600e6 (fullJitterBackoff 1e3)
-- | Run an action on a 'SqlReadBackend' as a
-- @SERIALIZABLE, READ ONLY, DEFERRABLE@ transaction, retrying on
-- serialization failures.
--
-- Full-jitter backoff with base delay @1e3@; retries stop once the cumulative
-- delay would exceed @3600e6@ (microseconds in the retry library, i.e. one
-- hour in total).
setSerializableDeferrableBatch :: forall m a. (MonadLogger m, MonadMask m, MonadIO m, ReadLogSettings (ReaderT SqlReadBackend m)) => ReaderT SqlReadBackend m a -> ReaderT SqlReadBackend m a
setSerializableDeferrableBatch = setSerializable' SerializableReadOnlyDeferrable retryPolicy
  where
    retryPolicy = limitRetriesByCumulativeDelay 3600e6 (fullJitterBackoff 1e3)
-- | Run a database action as a serializable transaction of the given
-- 'SerializableMode', retrying it according to @policy@.
--
-- Only exceptions that are (possibly wrapped) PostgreSQL serialization errors
-- are flagged as retryable. After the action succeeds, the transaction is
-- finished with @transactionSaveWithIsolation Persist.ReadCommitted@.
--
-- Every retry decision is logged under the source @SQL.setSerializable@: as an
-- error when no retry will happen, as info once the iteration number has
-- reached the configured @logSerializableTransactionRetryLimit@, and as debug
-- otherwise.
setSerializable' :: forall backend m a. (MonadLogger m, MonadMask m, MonadIO m, ReadLogSettings (ReaderT backend m), BackendCompatible SqlBackend backend) => SerializableMode -> RetryPolicyM (ReaderT backend m) -> ReaderT backend m a -> ReaderT backend m a
setSerializable' mode policy act = do
  LogSettings{logSerializableTransactionRetryLimit = retryLimit} <- readLogSettings
  -- Records whether the most recent attempt got as far as committing.
  committedVar <- newTVarIO False
  recovering
    policy
    (skipAsyncExceptions `snoc` logRetries isRetryable (reportRetry retryLimit))
    (attempt committedVar)
  where
    -- Only serialization failures are worth another attempt.
    isRetryable :: SomeException -> ReaderT backend m Bool
    isRetryable someExc = pure $ case fromExceptionWrapped someExc of
      Just sqlErr -> isSerializationError sqlErr
      Nothing -> False

    -- Log a retry decision; the severity depends on whether a retry follows
    -- and on how many iterations have already happened.
    reportRetry
      :: Maybe Natural -- ^ Iteration number from which retries are logged at info level
      -> Bool -- ^ Will retry
      -> SomeException
      -> RetryStatus
      -> ReaderT backend m ()
    reportRetry mLimit willRetry err status
      | not willRetry = $logErrorS "SQL.setSerializable" msg
      | Just limit <- mLimit
      , fromIntegral limit <= rsIterNumber status = $logInfoS "SQL.setSerializable" msg
      | otherwise = $logDebugS "SQL.setSerializable" msg
      where
        msg :: Text
        msg = pack $ defaultLogMsg willRetry err status

    -- Statement that changes the isolation level of the current transaction.
    setTransactionLevel :: ReaderT backend m ()
    setTransactionLevel = case mode of
      Serializable -> [executeQQ|SET TRANSACTION ISOLATION LEVEL SERIALIZABLE|]
      SerializableReadOnly -> [executeQQ|SET TRANSACTION ISOLATION LEVEL SERIALIZABLE, READ ONLY|]
      SerializableReadOnlyDeferrable -> [executeQQ|SET TRANSACTION ISOLATION LEVEL SERIALIZABLE, READ ONLY, DEFERRABLE|]

    -- Statement that opens a new transaction at the requested isolation level.
    beginTransactionLevel :: ReaderT backend m ()
    beginTransactionLevel = case mode of
      Serializable -> [executeQQ|BEGIN TRANSACTION ISOLATION LEVEL SERIALIZABLE|]
      SerializableReadOnly -> [executeQQ|BEGIN TRANSACTION ISOLATION LEVEL SERIALIZABLE, READ ONLY|]
      SerializableReadOnlyDeferrable -> [executeQQ|BEGIN TRANSACTION ISOLATION LEVEL SERIALIZABLE, READ ONLY, DEFERRABLE|]

    -- A single attempt: prepare the transaction, run the action, commit.
    attempt :: TVar Bool -> RetryStatus -> ReaderT backend m a
    attempt committedVar status = do
      prevCommited <- atomically $ swapTVar committedVar False
      $logDebugS "SQL.setSerializable" $ "prevCommited = " <> tshow prevCommited <> "; rsIterNumber = " <> tshow iteration
      openTransaction prevCommited *> runAndCommit
      where
        iteration = rsIterNumber status

        openTransaction prevCommited
          -- First attempt: only adjust the isolation level.
          | iteration == 0 = setTransactionLevel
          -- The previous attempt failed after the action itself had finished
          -- (the flag is set just before committing), so begin a new
          -- transaction. NOTE(review): presumably none is open at this
          -- point; confirm.
          | prevCommited = beginTransactionLevel
          -- The previous attempt failed inside the action: undo it, then
          -- adjust the isolation level of the transaction that follows.
          | otherwise = withReaderT projectBackend transactionUndo *> setTransactionLevel

        runAndCommit = do
          res <- act
          atomically $ writeTVar committedVar True
          withReaderT projectBackend $ transactionSaveWithIsolation Persist.ReadCommitted
          return res
-- | 'handleSql' with the arguments flipped: the action comes first, the
-- recovery function second.
catchSql :: forall e m a. (MonadCatch m, MonadIO m, Exception e) => SqlPersistT m a -> (e -> SqlPersistT m a) -> SqlPersistT m a
catchSql act recover = handleSql recover act
-- | Run a database action inside a savepoint and recover from exceptions of
-- type @e@.
--
-- A savepoint with a random (UUID) name is created before @act@ runs. If @act@
-- finishes normally the savepoint is released and the result returned. If an
-- exception of type @e@ is thrown (directly or wrapped in
-- 'ExceptionInLinkedThread'), the transaction is rolled back to the savepoint,
-- the savepoint is released, and @recover@ is run with the exception. Any
-- other exception is rethrown untouched.
handleSql :: forall e m a. (MonadCatch m, MonadIO m, Exception e) => (e -> SqlPersistT m a) -> SqlPersistT m a -> SqlPersistT m a
handleSql recover act = do
  savepointName <- liftIO $ UUID.toString <$> getRandom
  let recover' :: SomeException -> SqlPersistT m a
      recover' (fromExceptionWrapped -> Just exc) = do
        rawExecute [st|ROLLBACK TO SAVEPOINT "#{savepointName}"|] []
        -- Rolling back keeps the savepoint defined; release it so that
        -- repeated failures do not pile up savepoints in the transaction.
        rawExecute [st|RELEASE SAVEPOINT "#{savepointName}"|] []
        recover exc
      recover' exc = throwM exc
  -- Created outside the handled region: if the savepoint cannot be
  -- established there is nothing to roll back to, and attempting it would
  -- only replace the original error with a rollback failure.
  rawExecute [st|SAVEPOINT "#{savepointName}"|] []
  handle recover' $ do
    res <- act
    rawExecute [st|RELEASE SAVEPOINT "#{savepointName}"|] []
    return res
-- | 'handleIfSql' with the action before the recovery function: recover only
-- from those exceptions of type @e@ that satisfy the predicate.
catchIfSql :: forall e m a. (MonadCatch m, MonadIO m, Exception e) => (e -> Bool) -> SqlPersistT m a -> (e -> SqlPersistT m a) -> SqlPersistT m a
catchIfSql p act recover = handleIfSql p recover act
-- | Like 'handleSql', but only exceptions satisfying the predicate are passed
-- to @recover@; the others are rethrown.
--
-- The predicate is evaluated inside the recovery function given to
-- 'handleSql', i.e. after 'handleSql' has already reacted to the exception.
handleIfSql :: forall e m a. (MonadCatch m, MonadIO m, Exception e) => (e -> Bool) -> (e -> SqlPersistT m a) -> SqlPersistT m a -> SqlPersistT m a
handleIfSql p recover = handleSql recoverIfMatching
  where
    recoverIfMatching err
      | p err = recover err
      | otherwise = throwM err
-- | Run a database action via 'handleSql', returning an exception of type @e@
-- as 'Left' and a normal result as 'Right'.
trySql :: forall e m a. (MonadCatch m, MonadIO m, Exception e) => SqlPersistT m a -> SqlPersistT m (Either e a)
trySql act = handleSql (pure . Left) (Right <$> act)
-- | Does the given 'SqlError' report a violated unique constraint?
--
-- The SQLSTATE code @23505@ (@unique_violation@) is checked because, unlike
-- the message text, it does not depend on the language the server reports
-- errors in. The original message-prefix check is kept as a fallback so that
-- every error recognised before is still recognised.
isUniqueConstraintViolation :: SqlError -> Bool
isUniqueConstraintViolation SqlError{..} = hasUniqueViolationState || hasUniqueViolationMessage
  where
    hasUniqueViolationState = sqlState == "23505"
    hasUniqueViolationMessage = "duplicate key value violates unique constraint" `ByteString.isPrefixOf` sqlErrorMsg