105 lines
4.7 KiB
Haskell
105 lines
4.7 KiB
Haskell
module Utils.Sql
|
|
( setSerializable, setSerializableBatch, setSerializable'
|
|
, catchSql, handleSql
|
|
, isUniqueConstraintViolation
|
|
, catchIfSql, handleIfSql
|
|
) where
|
|
|
|
import ClassyPrelude.Yesod hiding (handle)
|
|
import Numeric.Natural
|
|
import Settings.Log
|
|
|
|
import Database.PostgreSQL.Simple (SqlError(..))
|
|
import Database.PostgreSQL.Simple.Errors (isSerializationError)
|
|
import Control.Monad.Catch
|
|
|
|
import Database.Persist.Sql
|
|
import Database.Persist.Sql.Raw.QQ
|
|
|
|
import qualified Data.ByteString as ByteString
|
|
|
|
import Control.Retry
|
|
|
|
import Control.Lens ((&))
|
|
|
|
import qualified Data.UUID as UUID
|
|
import Control.Monad.Random.Class (MonadRandom(getRandom))
|
|
|
|
import Text.Shakespeare.Text (st)
|
|
|
|
import Control.Concurrent.Async (ExceptionInLinkedThread(..))
|
|
|
|
|
|
-- | Like 'fromException', but when the exception is an
-- 'ExceptionInLinkedThread' it looks inside the wrapper (recursively, so
-- several layers of wrapping are handled) for an exception of the requested
-- type.
fromExceptionWrapped :: Exception exc => SomeException -> Maybe exc
fromExceptionWrapped someExc
  | Just exc <- fromException someExc
  = Just exc
  | Just (ExceptionInLinkedThread _ inner) <- fromException someExc
  = fromExceptionWrapped inner
  | otherwise
  = Nothing
|
|
|
|
|
|
-- | 'setSerializable'' with a full-jitter backoff (base @1e3@ µs) that stops
-- retrying once the cumulative delay reaches @10e6@ µs (10 s).
setSerializable :: forall m a. (MonadLogger m, MonadMask m, MonadIO m, ReadLogSettings (SqlPersistT m)) => SqlPersistT m a -> SqlPersistT m a
setSerializable = setSerializable' retryPolicy
  where
    retryPolicy :: RetryPolicyM (SqlPersistT m)
    retryPolicy = fullJitterBackoff 1e3
      & limitRetriesByCumulativeDelay 10e6
|
|
|
|
-- | Variant of 'setSerializable' for long-running batch work: same full-jitter
-- backoff (base @1e3@ µs), but retries only stop once the cumulative delay
-- reaches @3600e6@ µs (one hour).
setSerializableBatch :: forall m a. (MonadLogger m, MonadMask m, MonadIO m, ReadLogSettings (SqlPersistT m)) => SqlPersistT m a -> SqlPersistT m a
setSerializableBatch = setSerializable' batchPolicy
  where
    batchPolicy :: RetryPolicyM (SqlPersistT m)
    batchPolicy = fullJitterBackoff 1e3
      & limitRetriesByCumulativeDelay 3600e6
|
|
|
|
-- | Run @act@ as a @SERIALIZABLE@ transaction, re-running it under @policy@
-- whenever it fails with an error accepted by 'isSerializationError' (also
-- when that error arrives wrapped in an 'ExceptionInLinkedThread').
--
-- How each attempt starts its transaction:
--
-- * first attempt: @SET TRANSACTION ISOLATION LEVEL SERIALIZABLE@ on the
--   currently open transaction;
-- * retry after the previous attempt had reached its commit step:
--   @BEGIN TRANSACTION ISOLATION LEVEL SERIALIZABLE@;
-- * any other retry: 'transactionUndoWithIsolation' 'Serializable'.
--
-- After @act@ succeeds, 'transactionSaveWithIsolation' 'ReadCommitted' is
-- called, so the caller continues in a read-committed transaction.
--
-- Retry reports use the log source @"SQL.setSerializable"@:
--
-- * error level when giving up;
-- * info level once the iteration number reaches
--   'logSerializableTransactionRetryLimit' (if set);
-- * debug level otherwise.
setSerializable' :: forall m a. (MonadLogger m, MonadMask m, MonadIO m, ReadLogSettings (SqlPersistT m)) => RetryPolicyM (SqlPersistT m) -> SqlPersistT m a -> ReaderT SqlBackend m a
setSerializable' policy act = do
  LogSettings{ logSerializableTransactionRetryLimit = retryLimit } <- readLogSettings
  reachedCommit <- newTVarIO False
  let handlers = skipAsyncExceptions `snoc` logRetries shouldRetry (reportRetry retryLimit)
  recovering policy handlers (attempt reachedCommit)
  where
    -- Only errors classified by 'isSerializationError' warrant another attempt.
    shouldRetry :: SomeException -> ReaderT SqlBackend m Bool
    shouldRetry someExc = return $ case fromExceptionWrapped someExc of
      Just sqlErr -> isSerializationError sqlErr
      Nothing     -> False

    reportRetry :: Maybe Natural
                -> Bool -- ^ Will retry
                -> SomeException
                -> RetryStatus
                -> ReaderT SqlBackend m ()
    reportRetry limit willRetry err status
      | not willRetry
      = $logErrorS "SQL.setSerializable" msg
      | Just l <- limit, fromIntegral l <= rsIterNumber status
      = $logInfoS "SQL.setSerializable" msg
      | otherwise
      = $logDebugS "SQL.setSerializable" msg
      where
        msg = pack $ defaultLogMsg willRetry err status

    attempt :: TVar Bool -> RetryStatus -> ReaderT SqlBackend m a
    attempt reachedCommit RetryStatus{ rsIterNumber = iteration } = do
      hadCommitted <- atomically $ swapTVar reachedCommit False
      $logDebugS "SQL.setSerializable" $ "prevCommited = " <> tshow hadCommitted <> "; rsIterNumber = " <> tshow iteration
      startTransaction hadCommitted
      result <- act
      -- NOTE(review): the flag is set *before* committing, presumably so that
      -- a failing COMMIT (which ends the transaction) makes the next attempt
      -- BEGIN afresh instead of rolling back — confirm.
      atomically $ writeTVar reachedCommit True
      transactionSaveWithIsolation ReadCommitted
      return result
      where
        startTransaction :: Bool -> ReaderT SqlBackend m ()
        startTransaction committedBefore
          | iteration == 0  = [executeQQ|SET TRANSACTION ISOLATION LEVEL SERIALIZABLE|]
          | committedBefore = [executeQQ|BEGIN TRANSACTION ISOLATION LEVEL SERIALIZABLE|]
          | otherwise       = transactionUndoWithIsolation Serializable
|
|
|
|
-- | 'handleSql' with the action first and the recovery handler second, in the
-- argument order of 'catch'.
catchSql :: forall m a. (MonadCatch m, MonadIO m) => SqlPersistT m a -> (SqlError -> SqlPersistT m a) -> SqlPersistT m a
catchSql act recover = handleSql recover act
|
|
|
|
-- | Run @act@ inside a fresh savepoint. If it throws a 'SqlError' (directly or
-- wrapped in an 'ExceptionInLinkedThread'):
--
-- * the database state is rolled back to the savepoint;
-- * the savepoint is released;
-- * @recover@ is run on the error.
--
-- Any other exception is rethrown unchanged.
--
-- The savepoint name is a freshly generated random UUID, so nested uses get
-- distinct names.
handleSql :: forall m a. (MonadCatch m, MonadIO m) => (SqlError -> SqlPersistT m a) -> SqlPersistT m a -> SqlPersistT m a
handleSql recover act = do
  savepointName <- liftIO $ UUID.toString <$> getRandom

  let recover' :: SomeException -> SqlPersistT m a
      recover' (fromExceptionWrapped -> Just exc) = do
        rawExecute [st|ROLLBACK TO SAVEPOINT "#{savepointName}"|] []
        -- ROLLBACK TO leaves the savepoint defined; release it so repeated
        -- recoveries within one transaction do not accumulate nested
        -- savepoints.
        rawExecute [st|RELEASE SAVEPOINT "#{savepointName}"|] []
        recover exc
      recover' exc = throwM exc

  -- Created outside the handler: if this statement fails there is no
  -- savepoint to roll back to, so its own error should reach the caller
  -- instead of being masked by a failing ROLLBACK TO.
  rawExecute [st|SAVEPOINT "#{savepointName}"|] []
  handle recover' $ do
    res <- act
    rawExecute [st|RELEASE SAVEPOINT "#{savepointName}"|] []
    return res
|
|
|
|
-- | 'handleIfSql' with the action before the recovery handler: errors
-- satisfying the predicate are passed to the handler, others are rethrown.
catchIfSql :: forall m a. (MonadCatch m, MonadIO m) => (SqlError -> Bool) -> SqlPersistT m a -> (SqlError -> SqlPersistT m a) -> SqlPersistT m a
catchIfSql p act recover = handleIfSql p recover act
|
|
|
|
-- | Like 'handleSql', but only errors for which @p@ holds are handed to
-- @recover@; the rest are rethrown from inside the 'handleSql' recovery
-- (i.e. after 'handleSql' has already rolled back to its savepoint).
handleIfSql :: forall m a. (MonadCatch m, MonadIO m) => (SqlError -> Bool) -> (SqlError -> SqlPersistT m a) -> SqlPersistT m a -> SqlPersistT m a
handleIfSql p recover = handleSql $ \err ->
  if p err
    then recover err
    else throwM err
|
|
|
|
-- | Does this error report a violated unique constraint?
--
-- The SQLSTATE code @23505@ (@unique_violation@) is checked first because,
-- unlike the human-readable message, it does not depend on the server's
-- message locale. The English message prefix is kept as a fallback so that
-- every error recognised before is still recognised.
isUniqueConstraintViolation :: SqlError -> Bool
isUniqueConstraintViolation SqlError{..} =
  sqlState == "23505"
    || "duplicate key value violates unique constraint" `ByteString.isPrefixOf` sqlErrorMsg
|