-- SPDX-FileCopyrightText: 2022 Gregor Kleen
--
-- SPDX-License-Identifier: AGPL-3.0-or-later

-- | Helpers for running Persistent SQL actions against PostgreSQL:
-- @SERIALIZABLE@ transactions with automatic retry on serialization
-- failures, and savepoint-scoped exception handling.
module Utils.Sql
  ( setSerializable
  , setSerializableBatch, setSerializableReadOnlyBatch, setSerializableDeferrableBatch
  , SerializableMode(..), setSerializable'
  , catchSql, handleSql
  , isUniqueConstraintViolation
  , catchIfSql, handleIfSql
  , trySql
  ) where

import ClassyPrelude.Yesod hiding (handle)
import Numeric.Natural
import Settings.Log

import Database.PostgreSQL.Simple (SqlError(..))
import Database.PostgreSQL.Simple.Errors (isSerializationError)
import Control.Monad.Catch

import Database.Persist.Sql hiding (IsolationLevel(..))
import qualified Database.Persist.Sql as Persist (IsolationLevel(..))
import Database.Persist.Sql.Types.Instances ()
import Database.Persist.Sql.Raw.QQ

import qualified Data.ByteString as ByteString

import Control.Retry
import Control.Lens ((&))

import qualified Data.UUID as UUID
import Control.Monad.Random.Class (MonadRandom(getRandom))

import Text.Shakespeare.Text (st)

import Control.Concurrent.Async (ExceptionInLinkedThread(..))

import Data.Universe

import Control.Monad.Trans.Reader (withReaderT)


-- | Like 'fromException', but also looks through (arbitrarily nested)
-- 'ExceptionInLinkedThread' wrappers to find an exception of the requested
-- type.
fromExceptionWrapped :: Exception exc => SomeException -> Maybe exc
fromExceptionWrapped (fromException -> Just exc) = Just exc
fromExceptionWrapped (fromException >=> \(ExceptionInLinkedThread _ exc') -> fromExceptionWrapped exc' -> Just exc) = Just exc
fromExceptionWrapped _ = Nothing


-- | Which flavour of @SERIALIZABLE@ transaction 'setSerializable'' opens.
data SerializableMode
  = Serializable                   -- ^ @SERIALIZABLE@
  | SerializableReadOnly           -- ^ @SERIALIZABLE, READ ONLY@
  | SerializableReadOnlyDeferrable -- ^ @SERIALIZABLE, READ ONLY, DEFERRABLE@
  deriving (Eq, Ord, Read, Show, Enum, Bounded, Generic)
  deriving anyclass (Universe, Finite)


-- | Run an action in a @SERIALIZABLE@ transaction, retrying on serialization
-- failures with full-jitter backoff (base delay 1e3 µs) until the cumulative
-- delay limit of 10e6 µs is reached.
setSerializable :: forall m a.
                   ( MonadLogger m, MonadMask m, MonadIO m
                   , ReadLogSettings (SqlPersistT m)
                   )
                => SqlPersistT m a -> SqlPersistT m a
setSerializable = setSerializable' Serializable $ fullJitterBackoff 1e3 & limitRetriesByCumulativeDelay 10e6

-- | Like 'setSerializable', but for long-running batch work: the cumulative
-- retry delay limit is 3600e6 µs.
setSerializableBatch :: forall m a.
                        ( MonadLogger m, MonadMask m, MonadIO m
                        , ReadLogSettings (SqlPersistT m)
                        )
                     => SqlPersistT m a -> SqlPersistT m a
setSerializableBatch = setSerializable' Serializable $ fullJitterBackoff 1e3 & limitRetriesByCumulativeDelay 3600e6

-- | Batch variant running a @SERIALIZABLE, READ ONLY@ transaction on a
-- read-only backend (cumulative retry delay limit 3600e6 µs).
setSerializableReadOnlyBatch :: forall m a.
                                ( MonadLogger m, MonadMask m, MonadIO m
                                , ReadLogSettings (ReaderT SqlReadBackend m)
                                )
                             => ReaderT SqlReadBackend m a -> ReaderT SqlReadBackend m a
setSerializableReadOnlyBatch = setSerializable' SerializableReadOnly $ fullJitterBackoff 1e3 & limitRetriesByCumulativeDelay 3600e6

-- | Batch variant running a @SERIALIZABLE, READ ONLY, DEFERRABLE@
-- transaction on a read-only backend (cumulative retry delay limit
-- 3600e6 µs).
setSerializableDeferrableBatch :: forall m a.
                                  ( MonadLogger m, MonadMask m, MonadIO m
                                  , ReadLogSettings (ReaderT SqlReadBackend m)
                                  )
                               => ReaderT SqlReadBackend m a -> ReaderT SqlReadBackend m a
setSerializableDeferrableBatch = setSerializable' SerializableReadOnlyDeferrable $ fullJitterBackoff 1e3 & limitRetriesByCumulativeDelay 3600e6

-- | Run @act@ at the isolation level selected by @mode@ and commit it.
--
-- Exceptions for which 'isSerializationError' holds (possibly wrapped in
-- 'ExceptionInLinkedThread') are retried according to @policy@; asynchronous
-- exceptions are never retried.
--
-- * The first attempt issues @SET TRANSACTION ...@ on the transaction that
--   is already open. NOTE(review): PostgreSQL only accepts this before the
--   transaction's first query, so callers presumably must not have run
--   anything yet.
-- * On a successful attempt, the transaction is committed via
--   'transactionSaveWithIsolation', which also opens a new @READ COMMITTED@
--   transaction for whatever follows.
-- * On a retry, if the previous attempt already reached the commit step
--   (tracked in a 'TVar'), a fresh transaction is opened with
--   @BEGIN TRANSACTION ...@. Presumably the failed @COMMIT@ has already ended
--   the transaction. Otherwise the open transaction is rolled back with
--   'transactionUndo' and the isolation level is set again.
--
-- Retries are logged at debug level. Once the iteration number reaches
-- @logSerializableTransactionRetryLimit@ (if set), they are logged at info
-- level. Giving up is logged at error level.
setSerializable' :: forall backend m a.
                    ( MonadLogger m, MonadMask m, MonadIO m
                    , ReadLogSettings (ReaderT backend m)
                    , BackendCompatible SqlBackend backend
                    )
                 => SerializableMode
                 -> RetryPolicyM (ReaderT backend m)
                 -> ReaderT backend m a
                 -> ReaderT backend m a
setSerializable' mode policy act = do
  LogSettings{logSerializableTransactionRetryLimit} <- readLogSettings
  didCommit <- newTVarIO False
  recovering policy (skipAsyncExceptions `snoc` logRetries suggestRetry (logRetry logSerializableTransactionRetryLimit)) $ act' didCommit
  where
    -- Retry only on serialization failures, looking through linked-thread wrappers.
    suggestRetry :: SomeException -> ReaderT backend m Bool
    suggestRetry = return . maybe False isSerializationError . fromExceptionWrapped

    logRetry :: Maybe Natural
             -> Bool -- ^ Will retry
             -> SomeException
             -> RetryStatus
             -> ReaderT backend m ()
    -- Giving up: always an error.
    logRetry _ shouldRetry@False err status = $logErrorS "SQL.setSerializable" . pack $ defaultLogMsg shouldRetry err status
    -- Retrying, but the configured iteration threshold has been reached.
    logRetry (Just limit) shouldRetry err status
      | fromIntegral limit <= rsIterNumber status = $logInfoS "SQL.setSerializable" . pack $ defaultLogMsg shouldRetry err status
    -- Ordinary retry.
    logRetry _ shouldRetry err status = $logDebugS "SQL.setSerializable" . pack $ defaultLogMsg shouldRetry err status

    -- Statements that configure the already-open transaction, or that begin
    -- a new one, at the requested isolation level and access mode.
    (setTransactionLevel, beginTransactionLevel) = case mode of
      Serializable ->
        ( [executeQQ|SET TRANSACTION ISOLATION LEVEL SERIALIZABLE|]
        , [executeQQ|BEGIN TRANSACTION ISOLATION LEVEL SERIALIZABLE|]
        )
      SerializableReadOnly ->
        ( [executeQQ|SET TRANSACTION ISOLATION LEVEL SERIALIZABLE, READ ONLY|]
        , [executeQQ|BEGIN TRANSACTION ISOLATION LEVEL SERIALIZABLE, READ ONLY|]
        )
      SerializableReadOnlyDeferrable ->
        ( [executeQQ|SET TRANSACTION ISOLATION LEVEL SERIALIZABLE, READ ONLY, DEFERRABLE|]
        , [executeQQ|BEGIN TRANSACTION ISOLATION LEVEL SERIALIZABLE, READ ONLY, DEFERRABLE|]
        )

    -- One attempt. @didCommit@ records whether the previous attempt got as
    -- far as the commit step. It is reset at the start of every attempt.
    act' :: TVar Bool -> RetryStatus -> ReaderT backend m a
    act' didCommit RetryStatus{..} = do
      prevCommited <- atomically $ swapTVar didCommit False
      $logDebugS "SQL.setSerializable" $ "prevCommited = " <> tshow prevCommited <> "; rsIterNumber = " <> tshow rsIterNumber
      -- First attempt: configure the current transaction.
      -- Previous attempt failed at/after commit: start a new transaction.
      -- Previous attempt failed mid-transaction: roll back, then reconfigure.
      if | rsIterNumber == 0 -> setTransactionLevel *> act''
         | prevCommited      -> beginTransactionLevel *> act''
         | otherwise         -> withReaderT projectBackend transactionUndo *> setTransactionLevel *> act''
      where
        act'' = do
          res <- act
          -- Mark the commit as attempted *before* committing, so a failure
          -- raised by the commit itself is recognised on the next attempt.
          atomically $ writeTVar didCommit True
          withReaderT projectBackend $ transactionSaveWithIsolation Persist.ReadCommitted
          return res


-- | 'handleSql' with its arguments flipped.
catchSql :: forall e m a. (MonadCatch m, MonadIO m, Exception e) => SqlPersistT m a -> (e -> SqlPersistT m a) -> SqlPersistT m a
catchSql = flip handleSql

-- | Run an action inside a fresh, randomly named SQL savepoint.
--
-- If the action throws an exception of type @e@ (possibly wrapped in
-- 'ExceptionInLinkedThread'), the database is rolled back to the savepoint,
-- the savepoint is released, and the handler is run on the unwrapped
-- exception. Any other exception is rethrown unchanged. On success the
-- savepoint is released.
--
-- NOTE(review): assumes an open transaction, since PostgreSQL only accepts
-- @SAVEPOINT@ inside a transaction block.
handleSql :: forall e m a. (MonadCatch m, MonadIO m, Exception e) => (e -> SqlPersistT m a) -> SqlPersistT m a -> SqlPersistT m a
handleSql recover act = do
  savepointName <- liftIO $ UUID.toString <$> getRandom
  let recover' :: SomeException -> SqlPersistT m a
      recover' (fromExceptionWrapped -> Just exc) = do
        rawExecute [st|ROLLBACK TO SAVEPOINT "#{savepointName}"|] []
        -- ROLLBACK TO keeps the savepoint defined. Release it so caught
        -- exceptions do not pile up open savepoints for the rest of the
        -- enclosing transaction.
        rawExecute [st|RELEASE SAVEPOINT "#{savepointName}"|] []
        recover exc
      recover' exc = throwM exc
  handle recover' $ do
    rawExecute [st|SAVEPOINT "#{savepointName}"|] []
    res <- act
    rawExecute [st|RELEASE SAVEPOINT "#{savepointName}"|] []
    return res

-- | 'handleIfSql' with the action and handler arguments swapped.
catchIfSql :: forall e m a. (MonadCatch m, MonadIO m, Exception e) => (e -> Bool) -> SqlPersistT m a -> (e -> SqlPersistT m a) -> SqlPersistT m a
catchIfSql p = flip $ handleIfSql p

-- | Like 'handleSql', but only handles exceptions satisfying the predicate.
--
-- Rejected exceptions are rethrown in their unwrapped form, and only after
-- the savepoint rollback has already been performed by 'handleSql'.
handleIfSql :: forall e m a. (MonadCatch m, MonadIO m, Exception e) => (e -> Bool) -> (e -> SqlPersistT m a) -> SqlPersistT m a -> SqlPersistT m a
handleIfSql p recover = handleSql (\err -> bool throwM recover (p err) err)

-- | Run an action inside a savepoint, returning exceptions of type @e@ as
-- 'Left' after rolling back to the savepoint.
trySql :: forall e m a. (MonadCatch m, MonadIO m, Exception e) => SqlPersistT m a -> SqlPersistT m (Either e a)
trySql = handleSql (return . Left) . fmap Right

-- | Does the error denote a unique-constraint violation?
--
-- Checks the SQLSTATE code @23505@ (@unique_violation@). Unlike the
-- human-readable message, the code does not depend on the server's message
-- locale, and this matches how 'isSerializationError' classifies errors. The
-- original English message-prefix test is kept as a fallback, so every error
-- matched before is still matched.
isUniqueConstraintViolation :: SqlError -> Bool
isUniqueConstraintViolation SqlError{..}
  =  sqlState == "23505"
  || "duplicate key value violates unique constraint" `ByteString.isPrefixOf` sqlErrorMsg