fradrive/src/Utils/Sql.hs
2020-09-22 14:14:25 +02:00

105 lines
4.7 KiB
Haskell

module Utils.Sql
( setSerializable, setSerializableBatch, setSerializable'
, catchSql, handleSql
, isUniqueConstraintViolation
, catchIfSql, handleIfSql
) where
import ClassyPrelude.Yesod hiding (handle)
import Numeric.Natural
import Settings.Log
import Database.PostgreSQL.Simple (SqlError(..))
import Database.PostgreSQL.Simple.Errors (isSerializationError)
import Control.Monad.Catch
import Database.Persist.Sql
import Database.Persist.Sql.Raw.QQ
import qualified Data.ByteString as ByteString
import Control.Retry
import Control.Lens ((&))
import qualified Data.UUID as UUID
import Control.Monad.Random.Class (MonadRandom(getRandom))
import Text.Shakespeare.Text (st)
import Control.Concurrent.Async (ExceptionInLinkedThread(..))
-- | Like 'fromException', but also looks through (arbitrarily deeply nested)
-- 'ExceptionInLinkedThread' wrappers for an exception of the requested type.
-- A direct match on the outer exception takes precedence over unwrapping.
fromExceptionWrapped :: Exception exc => SomeException -> Maybe exc
fromExceptionWrapped someExc = case fromException someExc of
  Just exc -> Just exc
  Nothing -> case fromException someExc of
    -- Exception re-thrown from a linked 'Async': inspect the inner exception.
    Just (ExceptionInLinkedThread _ inner) -> fromExceptionWrapped inner
    Nothing -> Nothing
-- | 'setSerializable'' with a policy suited to interactive requests:
-- full-jitter backoff with a base delay of 1e3 µs, giving up once the
-- cumulative retry delay reaches 10e6 µs (10 s).
setSerializable :: forall m a. (MonadLogger m, MonadMask m, MonadIO m, ReadLogSettings (SqlPersistT m)) => SqlPersistT m a -> SqlPersistT m a
setSerializable =
  setSerializable' (limitRetriesByCumulativeDelay 10e6 (fullJitterBackoff 1e3))
-- | 'setSerializable'' with a policy suited to long-running batch work:
-- full-jitter backoff with a base delay of 1e3 µs, giving up once the
-- cumulative retry delay reaches 3600e6 µs (one hour).
setSerializableBatch :: forall m a. (MonadLogger m, MonadMask m, MonadIO m, ReadLogSettings (SqlPersistT m)) => SqlPersistT m a -> SqlPersistT m a
setSerializableBatch =
  setSerializable' (limitRetriesByCumulativeDelay 3600e6 (fullJitterBackoff 1e3))
-- | Run @act@ at isolation level @SERIALIZABLE@, retrying it according to
-- @policy@ whenever it fails with a serialization error (also when that error
-- is wrapped in 'ExceptionInLinkedThread'). Asynchronous exceptions are never
-- retried.
--
-- Each attempt first prepares the transaction (see @startAttempt@), then runs
-- @act@, then commits and begins a new @READ COMMITTED@ transaction.
--
-- Failed attempts are logged under source @SQL.setSerializable@: at error
-- level when giving up, at info level once the attempt number reaches
-- 'logSerializableTransactionRetryLimit' (if set), at debug level otherwise.
setSerializable' :: forall m a. (MonadLogger m, MonadMask m, MonadIO m, ReadLogSettings (SqlPersistT m)) => RetryPolicyM (SqlPersistT m) -> SqlPersistT m a -> ReaderT SqlBackend m a
setSerializable' policy act = do
  LogSettings{logSerializableTransactionRetryLimit} <- readLogSettings
  -- Set once an attempt has reached the commit step; read and reset by the
  -- next attempt to decide how to restart the transaction.
  commitAttempted <- newTVarIO False
  let handlers = skipAsyncExceptions `snoc` logRetries isRetryable (reportRetry logSerializableTransactionRetryLimit)
  recovering policy handlers $ attempt commitAttempted
  where
    -- Only serialization failures are worth another attempt.
    isRetryable :: SomeException -> ReaderT SqlBackend m Bool
    isRetryable exc = return $ case fromExceptionWrapped exc of
      Just sqlErr -> isSerializationError sqlErr
      Nothing -> False

    reportRetry :: Maybe Natural -- configured retry limit for info-level logging
                -> Bool -- whether another attempt will be made
                -> SomeException
                -> RetryStatus
                -> ReaderT SqlBackend m ()
    reportRetry retryLimit willRetry exc status
      | not willRetry = $logErrorS "SQL.setSerializable" message
      | Just limit <- retryLimit
      , fromIntegral limit <= rsIterNumber status = $logInfoS "SQL.setSerializable" message
      | otherwise = $logDebugS "SQL.setSerializable" message
      where
        message :: Text
        message = pack $ defaultLogMsg willRetry exc status

    attempt :: TVar Bool -> RetryStatus -> ReaderT SqlBackend m a
    attempt commitAttempted status = do
      let iteration = rsIterNumber status
      prevCommited <- atomically $ swapTVar commitAttempted False
      $logDebugS "SQL.setSerializable" $ "prevCommited = " <> tshow prevCommited <> "; rsIterNumber = " <> tshow iteration
      startAttempt prevCommited iteration
      res <- act
      -- Flag is set *before* committing: if the commit itself throws, the
      -- next attempt issues a fresh BEGIN instead of rolling back
      -- (presumably because a failed COMMIT has already ended the
      -- transaction in PostgreSQL).
      atomically $ writeTVar commitAttempted True
      transactionSaveWithIsolation ReadCommitted
      return res

    -- Bring the connection into a serializable transaction for this attempt.
    startAttempt :: Bool -> Int -> ReaderT SqlBackend m ()
    startAttempt prevCommited iteration
      -- First attempt: raise the isolation level of the currently open transaction.
      | iteration == 0 = [executeQQ|SET TRANSACTION ISOLATION LEVEL SERIALIZABLE|]
      -- Previous attempt reached the commit step: start a new transaction.
      | prevCommited = [executeQQ|BEGIN TRANSACTION ISOLATION LEVEL SERIALIZABLE|]
      -- Previous attempt failed inside @act@: roll back and restart serializable.
      | otherwise = transactionUndoWithIsolation Serializable
-- | 'handleSql' with the action first and the handler second
-- (argument order of 'catch').
catchSql :: forall m a. (MonadCatch m, MonadIO m) => SqlPersistT m a -> (SqlError -> SqlPersistT m a) -> SqlPersistT m a
catchSql act recover = handleSql recover act
-- | Run @act@ inside a fresh, randomly named savepoint and recover from any
-- 'SqlError' it throws (also when wrapped in 'ExceptionInLinkedThread').
--
-- On an 'SqlError' the transaction is rolled back to the savepoint, which
-- discards the failed action's effects. The savepoint is then released, and
-- @recover@ is run with the error. Any other exception is rethrown unchanged.
-- On success the savepoint is released and @act@'s result is returned.
handleSql :: forall m a. (MonadCatch m, MonadIO m) => (SqlError -> SqlPersistT m a) -> SqlPersistT m a -> SqlPersistT m a
handleSql recover act = do
  savepointName <- liftIO $ UUID.toString <$> getRandom
  -- Create the savepoint *before* installing the handler. If this statement
  -- fails (e.g. the transaction is already aborted), its error propagates
  -- unchanged. It is not answered with a ROLLBACK TO a savepoint that was
  -- never created, which would mask it. Nor is it passed to @recover@, since
  -- @act@ did not raise it.
  rawExecute [st|SAVEPOINT "#{savepointName}"|] []
  let recover' :: SomeException -> SqlPersistT m a
      recover' (fromExceptionWrapped -> Just exc) = do
        rawExecute [st|ROLLBACK TO SAVEPOINT "#{savepointName}"|] []
        -- ROLLBACK TO keeps the savepoint open. Release it so repeated
        -- failures do not accumulate open savepoints in the transaction.
        rawExecute [st|RELEASE SAVEPOINT "#{savepointName}"|] []
        recover exc
      recover' exc = throwM exc
  handle recover' $ do
    res <- act
    rawExecute [st|RELEASE SAVEPOINT "#{savepointName}"|] []
    return res
-- | 'handleIfSql' with the action before the handler
-- (argument order of 'catch').
catchIfSql :: forall m a. (MonadCatch m, MonadIO m) => (SqlError -> Bool) -> SqlPersistT m a -> (SqlError -> SqlPersistT m a) -> SqlPersistT m a
catchIfSql p act recover = handleIfSql p recover act
-- | Like 'handleSql', but only errors satisfying @p@ are passed to @recover@.
-- Other 'SqlError's are rethrown after 'handleSql' has rolled back to its
-- savepoint.
handleIfSql :: forall m a. (MonadCatch m, MonadIO m) => (SqlError -> Bool) -> (SqlError -> SqlPersistT m a) -> SqlPersistT m a -> SqlPersistT m a
handleIfSql p recover = handleSql $ \err ->
  if p err
    then recover err
    else throwM err
-- | Does the error report a violated unique constraint?
--
-- Checks the SQLSTATE code @23505@ (@unique_violation@). Unlike the
-- human-readable message, this code does not change with the server's
-- @lc_messages@ locale. The English message prefix is still accepted as a
-- fallback, so every error matched previously is still matched.
isUniqueConstraintViolation :: SqlError -> Bool
isUniqueConstraintViolation SqlError{..} =
  sqlState == "23505"
    || "duplicate key value violates unique constraint" `ByteString.isPrefixOf` sqlErrorMsg