-- | Helpers for running Persistent SQL actions against PostgreSQL:
-- serializable transactions with automatic retry, and savepoint-scoped
-- handling of 'SqlError's.
module Utils.Sql
  ( setSerializable, setSerializable'
  , catchSql, handleSql
  , isUniqueConstraintViolation
  , catchIfSql, handleIfSql
  ) where

import ClassyPrelude.Yesod hiding (handle)
import Numeric.Natural
import Settings.Log

import Database.PostgreSQL.Simple (SqlError(..))
import Database.PostgreSQL.Simple.Errors (isSerializationError)
import Control.Monad.Catch

import Database.Persist.Sql
import Database.Persist.Sql.Raw.QQ

import qualified Data.ByteString as ByteString

import Control.Retry
import Control.Lens ((&))

import qualified Data.UUID as UUID
import Control.Monad.Random.Class (MonadRandom(getRandom))

import Text.Shakespeare.Text (st)


-- | 'setSerializable'' with a default retry policy.
--
-- The policy is 'fullJitterBackoff' with a base delay of 1e3 microseconds.
-- Retries stop once the cumulative delay reaches 10e6 microseconds.
setSerializable :: forall m a. (MonadLogger m, MonadMask m, MonadIO m, ReadLogSettings (SqlPersistT m)) => SqlPersistT m a -> SqlPersistT m a
setSerializable = setSerializable' $ fullJitterBackoff 1e3 & limitRetriesByCumulativeDelay 10e6

-- | Run an action in a @SERIALIZABLE@ transaction, retrying it according to
-- @policy@ whenever it fails with a serialization error
-- ('isSerializationError').
--
-- * On the first attempt the current transaction is switched to
--   @SERIALIZABLE@ via @SET TRANSACTION@. PostgreSQL only accepts this before
--   the transaction's first query, so no query should have been issued in the
--   current transaction before calling this.
-- * On retries the transaction is rolled back and a new @SERIALIZABLE@
--   transaction is begun.
-- * On success the transaction is committed and a new @READ COMMITTED@
--   transaction is begun ('transactionSaveWithIsolation'). The action's
--   result is then returned.
--
-- Each failure is logged under the source @SQL.setSerializable@:
--
-- * A failure that will not be retried is logged at error level.
-- * A retry at or beyond @logSerializableTransactionRetryLimit@ iterations
--   is logged at info level.
-- * Any other retry is logged at debug level.
setSerializable'
  :: forall m a.
     ( MonadLogger m, MonadMask m, MonadIO m
     , ReadLogSettings (SqlPersistT m)
     )
  => RetryPolicyM (SqlPersistT m) -- ^ Retry/backoff policy
  -> SqlPersistT m a              -- ^ Action to run serializably
  -> ReaderT SqlBackend m a
setSerializable' policy act = do
  LogSettings{logSerializableTransactionRetryLimit} <- readLogSettings
  -- Set just before the commit attempt of each try; read (and reset) at the
  -- start of the next try.
  didCommit <- newTVarIO False
  recovering
    policy
    (skipAsyncExceptions `snoc` logRetries suggestRetry (logRetry logSerializableTransactionRetryLimit))
    (act' didCommit)
  where
    -- Only serialization failures are worth retrying.
    suggestRetry :: SqlError -> ReaderT SqlBackend m Bool
    suggestRetry = return . isSerializationError

    logRetry :: Maybe Natural
             -> Bool -- Will retry
             -> SqlError
             -> RetryStatus
             -> ReaderT SqlBackend m ()
    logRetry _ shouldRetry@False err status =
      $logErrorS "SQL.setSerializable" . pack $ defaultLogMsg shouldRetry err status
    logRetry (Just limit) shouldRetry err status
      | fromIntegral limit <= rsIterNumber status =
          $logInfoS "SQL.setSerializable" . pack $ defaultLogMsg shouldRetry err status
    logRetry _ shouldRetry err status =
      $logDebugS "SQL.setSerializable" . pack $ defaultLogMsg shouldRetry err status

    act' :: TVar Bool -> RetryStatus -> ReaderT SqlBackend m a
    act' didCommit RetryStatus{..} = do
      prevCommited <- atomically $ swapTVar didCommit False
      $logDebugS "SQL.setSerializable" $ "prevCommited = " <> tshow prevCommited <> "; rsIterNumber = " <> tshow rsIterNumber
      -- If the previous attempt failed during its commit, a fresh transaction
      -- is opened with BEGIN instead of rolling back. Presumably this is
      -- because a failed COMMIT leaves no open transaction; TODO confirm.
      if | rsIterNumber == 0 -> [executeQQ|SET TRANSACTION ISOLATION LEVEL SERIALIZABLE|] *> act''
         | prevCommited      -> [executeQQ|BEGIN TRANSACTION ISOLATION LEVEL SERIALIZABLE|] *> act''
         | otherwise         -> transactionUndoWithIsolation Serializable *> act''
      where
        act'' = do
          res <- act
          -- Mark before committing, so that a failure raised by the commit
          -- itself is recognisable on the next attempt.
          atomically $ writeTVar didCommit True
          transactionSaveWithIsolation ReadCommitted
          return res


-- | Flipped 'handleSql'.
catchSql :: forall m a. (MonadCatch m, MonadIO m) => SqlPersistT m a -> (SqlError -> SqlPersistT m a) -> SqlPersistT m a
catchSql = flip handleSql

-- | Run @act@ inside a fresh savepoint and recover from 'SqlError's with
-- @recover@.
--
-- If @act@ (or releasing the savepoint) throws a 'SqlError':
--
-- * The transaction is first rolled back to the savepoint.
-- * Then @recover@ is run, so it can keep issuing statements in the same
--   transaction.
-- * The savepoint is not released on that path.
--
-- If the action succeeds, the savepoint is released and the result is
-- returned. A failure to create the savepoint is not handled: that error
-- propagates unchanged.
handleSql :: forall m a. (MonadCatch m, MonadIO m) => (SqlError -> SqlPersistT m a) -> SqlPersistT m a -> SqlPersistT m a
handleSql recover act = do
  -- A random UUID contains only hex digits and dashes, so it is safe inside a
  -- quoted identifier. It also makes collisions between nested calls
  -- practically impossible.
  savepointName <- liftIO $ UUID.toString <$> getRandom
  let recover' :: SqlError -> SqlPersistT m a
      recover' exc = do
        rawExecute [st|ROLLBACK TO SAVEPOINT "#{savepointName}"|] []
        recover exc
  -- Create the savepoint outside the handler. If this statement fails, there
  -- is no savepoint to roll back to. ROLLBACK TO SAVEPOINT would then only
  -- replace the original error with a "savepoint does not exist" error.
  rawExecute [st|SAVEPOINT "#{savepointName}"|] []
  handle recover' $ do
    res <- act
    rawExecute [st|RELEASE SAVEPOINT "#{savepointName}"|] []
    return res

-- | Flipped 'handleIfSql'.
catchIfSql :: forall m a. (MonadCatch m, MonadIO m) => (SqlError -> Bool) -> SqlPersistT m a -> (SqlError -> SqlPersistT m a) -> SqlPersistT m a
catchIfSql p = flip $ handleIfSql p

-- | Like 'handleSql', but only errors satisfying @p@ are passed to @recover@.
-- Other errors are rethrown.
--
-- Note that the rollback to the savepoint (done by 'handleSql') happens
-- before @p@ is consulted, so it also applies to rethrown errors.
handleIfSql :: forall m a. (MonadCatch m, MonadIO m) => (SqlError -> Bool) -> (SqlError -> SqlPersistT m a) -> SqlPersistT m a -> SqlPersistT m a
handleIfSql p recover = handleSql (\err -> bool throwM recover (p err) err)

-- | Whether the error is a PostgreSQL @unique_violation@ (SQLSTATE @23505@).
--
-- The SQLSTATE is checked first because, unlike the human-readable message,
-- it does not depend on the server's @lc_messages@ locale. The English
-- message prefix is kept as a fallback for backward compatibility.
isUniqueConstraintViolation :: SqlError -> Bool
isUniqueConstraintViolation SqlError{..}
  =  sqlState == "23505"
  || "duplicate key value violates unique constraint" `ByteString.isPrefixOf` sqlErrorMsg