fradrive/src/Web/ServerSession/Backend/Persistent/Memcached.hs

182 lines
7.5 KiB
Haskell

-- SPDX-FileCopyrightText: 2022 Gregor Kleen <gregor.kleen@ifi.lmu.de>
--
-- SPDX-License-Identifier: AGPL-3.0-or-later
{-# OPTIONS_GHC -fno-warn-orphans #-}
{-# LANGUAGE UndecidableInstances #-}
module Web.ServerSession.Backend.Persistent.Memcached
( migrateMemcachedSqlStorage
, MemcachedSessionExpirationId, MemcachedSessionExpiration(..)
, MemcachedSqlStorage(..)
, _mcdSqlConnPool, _mcdSqlMemcached, _mcdSqlMemcachedKey, _mcdSqlMemcachedExpiration
) where
import Import.NoModel hiding (AuthId, SessionMap, getSession)
import Utils.Lens
import Web.ServerSession.Core
import qualified Utils.Pool as Custom
import qualified Data.Binary as Binary
import qualified Database.Memcached.Binary.IO as Memcached
import qualified Crypto.Saltine.Class as Saltine
import qualified Crypto.Saltine.Internal.ByteSizes as Saltine
import qualified Crypto.Saltine.Core.AEAD as AEAD
import qualified Data.ByteString as BS
import qualified Data.ByteString.Lazy as LBS
import qualified Data.ByteString.Base64.URL as Base64
import Utils.Metrics (DBConnUseState)
import Data.Text.Encoding (decodeUtf8')
import Data.Bits (Bits(zeroBits))
import Data.Time.Clock.POSIX (getPOSIXTime)
-- SQL schema of this backend: one row per auth id, holding the time at which
-- all sessions of that auth id were last invalidated. 'deleteAllSessionsOfAuthId'
-- upserts this timestamp; 'getSession' rejects sessions created at or before it.
-- The splice also generates the 'migrateMemcachedSqlStorage' migration.
share [mkPersist sqlSettings, mkMigrate "migrateMemcachedSqlStorage"]
  [persistLowerCase|
MemcachedSessionExpiration
  authId ByteString
  time UTCTime
  UniqueMemcachedSessionExpiration authId
  deriving Show Eq Ord
|]
-- | Server-session storage backend. Encrypted session payloads are kept in
-- memcached. Per-auth-id invalidation timestamps are kept in SQL.
data MemcachedSqlStorage sess = MemcachedSqlStorage
  { mcdSqlConnPool :: forall m. MonadIO m => Custom.Pool' m DBConnLabel DBConnUseState SqlBackend
    -- ^ SQL connection pool; 'TransactionM' actions are run against it.
  , mcdSqlMemcached :: Memcached.Connection
    -- ^ Connection to the memcached server that holds the session payloads.
  , mcdSqlMemcachedKey :: AEAD.Key
    -- ^ Key used to encrypt and authenticate the payloads stored in memcached.
  , mcdSqlMemcachedExpiration :: Maybe NominalDiffTime
    -- ^ Lifetime of memcached entries, rounded up to whole seconds.
    -- 'Nothing' is sent as an expiry of 0 (no expiration in memcached).
  }
-- Generates the @_mcdSql…@ lenses listed in the module's export list.
makeLenses_ ''MemcachedSqlStorage
-- | Backend-specific failures, raised with 'throwM' when a session cannot be
-- written to memcached or a stored value cannot be decrypted or decoded.
data MemcachedSqlStorageException
  = MemcachedSqlStorageKeyCollision
    -- ^ memcached reported that the session's key already exists while writing.
  | MemcachedSqlStorageAEADCiphertextTooShort
    -- ^ The stored value is shorter than a nonce plus an authentication tag.
  | MemcachedSqlStorageAEADCouldNotDecodeNonce
    -- ^ The leading bytes of the stored value are not a valid nonce.
  | MemcachedSqlStorageAEADCouldNotOpenAEAD
    -- ^ Decryption or authentication failed. Causes include a wrong key, a
    -- tampered value, or a value stored under a different session id.
  | MemcachedSqlStorageAEADCouldDecodeMemcachedSqlSession
    -- ^ The decrypted plaintext is not a fully consumed binary-encoded session.
    -- NOTE(review): the name lacks a "Not" ("CouldNotDecode"). Renaming it
    -- requires updating every use site.
  deriving (Eq, Ord, Read, Show, Generic)
instance Exception MemcachedSqlStorageException
-- | Everything in a 'Session' except its id. This is the value that is
-- binary-encoded, encrypted and stored in memcached; the id itself is used to
-- build the memcached key.
data MemcachedSqlSession sess = MemcachedSqlSession
  { mcdSqlSessionAuthId :: Maybe AuthId
  , mcdSqlSessionData :: Decomposed sess
  , mcdSqlSessionCreatedAt, mcdSqlSessionAccessedAt :: UTCTime
  } deriving (Generic)
deriving instance Eq (Decomposed sess) => Eq (MemcachedSqlSession sess)
deriving instance Ord (Decomposed sess) => Ord (MemcachedSqlSession sess)
deriving instance Read (Decomposed sess) => Read (MemcachedSqlSession sess)
deriving instance Show (Decomposed sess) => Show (MemcachedSqlSession sess)
-- Generic-derived, so the wire format follows the field order above. Changing
-- the fields or their order makes the encoding incompatible with values that
-- are already stored in memcached.
instance Binary (Decomposed sess) => Binary (MemcachedSqlSession sess)
-- | Compact encoding of a session id. The base64url path piece is decoded, and
-- its first 18 raw bytes are written without a length prefix. Decoding reads
-- exactly 18 bytes back and fails if they do not form a valid session id.
instance Binary (SessionId sess) where
  get = do
    rawBytes <- replicateM 18 Binary.getWord8
    case fromPathPiece . decodeUtf8 . Base64.encode $ BS.pack rawBytes of
      Just sessId -> return sessId
      Nothing -> fail "Could not decode SessionId fromPathPiece"
  put sessId = mapM_ Binary.putWord8 firstBytes
    where
      firstBytes = take 18 . BS.unpack . Base64.decodeLenient . encodeUtf8 $ toPathPiece sessId
-- | Splits a 'Session' into its id (from which the memcached key is built) and
-- the remaining fields that get serialised. The inverse direction joins the
-- two parts back into a 'Session'.
memcachedSqlSession :: Iso' (SessionId sess, MemcachedSqlSession sess) (Session sess)
memcachedSqlSession = iso fromStored toStored
  where
    fromStored (sid, stored) = Session
      { sessionKey        = sid
      , sessionAuthId     = mcdSqlSessionAuthId stored
      , sessionData       = mcdSqlSessionData stored
      , sessionCreatedAt  = mcdSqlSessionCreatedAt stored
      , sessionAccessedAt = mcdSqlSessionAccessedAt stored
      }
    toStored Session { sessionKey = sid
                     , sessionAuthId = sAuth
                     , sessionData = sData
                     , sessionCreatedAt = sCreated
                     , sessionAccessedAt = sAccessed
                     }
      = ( sid
        , MemcachedSqlSession
            { mcdSqlSessionAuthId     = sAuth
            , mcdSqlSessionData       = sData
            , mcdSqlSessionCreatedAt  = sCreated
            , mcdSqlSessionAccessedAt = sAccessed
            }
        )
-- Orphan instance: 'SessionMap' reuses the 'Binary' encoding of the type it wraps.
deriving newtype instance Binary SessionMap
-- | Memcached key of a session: the UTF-8 encoding of the id's path piece.
-- Previewing fails for keys that are not valid UTF-8 or not a valid session id.
memcachedSqlSessionId :: Prism' ByteString (SessionId dat)
memcachedSqlSessionId = prism' renderKey parseKey
  where
    renderKey sessId = encodeUtf8 (toPathPiece sessId)
    parseKey raw = case decodeUtf8' raw of
      Left _    -> Nothing
      Right txt -> fromPathPiece txt
-- | Storage backend that keeps encrypted session payloads in memcached and
-- per-auth-id invalidation timestamps in SQL.
--
-- Each memcached value is the encoded nonce followed by the AEAD ciphertext.
-- The binary-encoded session id is the additional authenticated data, so a
-- value only opens under the id it was written for.
instance (IsSessionData sess, Binary (Decomposed sess)) => Storage (MemcachedSqlStorage sess) where
  type SessionData (MemcachedSqlStorage sess) = sess
  type TransactionM (MemcachedSqlStorage sess) = SqlPersistT IO

  runTransactionM MemcachedSqlStorage{..} act = customRunSqlPool act mcdSqlConnPool

  -- Fetches the value from memcached and refreshes its expiry. It then
  -- decrypts and decodes the value and checks it against the SQL invalidation
  -- timestamp of its auth id. Missing or invalidated sessions yield 'Nothing';
  -- malformed values throw a 'MemcachedSqlStorageException'.
  getSession MemcachedSqlStorage{..} sessId
    = exceptT (maybe (return Nothing) throwM) (return . Just) $ do
        expiry <- liftIO memcachedExpiry
        -- A missing key means "no such session" ('Nothing'), not an error.
        encSession <- catchIfExceptT (const Nothing) Memcached.isKeyNotFound . liftIO . fmap LBS.toStrict $
          Memcached.getAndTouch_ expiry (memcachedSqlSessionId # sessId) mcdSqlMemcached
        -- NOTE(review): the length check and nonce split use the secretbox
        -- sizes. This assumes the AEAD construction uses the same nonce and
        -- tag lengths; confirm against the saltine version in use.
        guardExceptT (BS.length encSession >= Saltine.secretBoxNonce + Saltine.secretBoxMac) $
          Just MemcachedSqlStorageAEADCiphertextTooShort
        let (nonceBS, encrypted) = BS.splitAt Saltine.secretBoxNonce encSession
            encSessId = LBS.toStrict $ Binary.encode sessId
        nonce <- maybeTExceptT (Just MemcachedSqlStorageAEADCouldNotDecodeNonce) . hoistMaybe $ Saltine.decode nonceBS
        decrypted <- maybeTExceptT (Just MemcachedSqlStorageAEADCouldNotOpenAEAD) . hoistMaybe $ AEAD.aeadOpen mcdSqlMemcachedKey nonce encrypted encSessId
        -- Decoding must consume the whole plaintext; leftover bytes are rejected.
        let binaryDecode bs = do
              Right (unconsumed, _, res) <- return $ Binary.decodeOrFail bs
              guard $ LBS.null unconsumed
              return res
        decoded@MemcachedSqlSession{..} <- maybeTExceptT (Just MemcachedSqlStorageAEADCouldDecodeMemcachedSqlSession) . hoistMaybe . binaryDecode $ LBS.fromStrict decrypted
        -- Sessions created at or before their auth id's invalidation timestamp
        -- (written by 'deleteAllSessionsOfAuthId') are reported as missing.
        expiration <- runMaybeT $ fmap (memcachedSessionExpirationTime . entityVal) . MaybeT . lift . getBy . UniqueMemcachedSessionExpiration =<< hoistMaybe mcdSqlSessionAuthId
        guardExceptT (maybe True (mcdSqlSessionCreatedAt >) expiration) Nothing
        return $ (sessId, decoded) ^. memcachedSqlSession
    where
      -- memcached treats expiration values above 30 days (2592000 seconds) as
      -- absolute Unix timestamps rather than offsets from now. Longer
      -- lifetimes are therefore sent as an absolute timestamp. 'Nothing' is
      -- sent as 0, i.e. no expiration.
      memcachedExpiry = case mcdSqlMemcachedExpiration of
        Nothing -> return 0
        Just ttl
          | ttl <= 60 * 60 * 24 * 30 -> return $ ceiling ttl
          | otherwise -> ceiling . (+ ttl) <$> getPOSIXTime

  -- Deleting a session that is already gone is not an error.
  deleteSession MemcachedSqlStorage{..} sessId
    = liftIO . handleIf Memcached.isKeyNotFound (const $ return ()) $ Memcached.delete (memcachedSqlSessionId # sessId) mcdSqlMemcached

  -- Memcached entries are left in place. Moving the auth id's invalidation
  -- timestamp to now makes 'getSession' reject every session created so far.
  deleteAllSessionsOfAuthId MemcachedSqlStorage{} authId = do
    now <- liftIO getCurrentTime
    void $ upsert
      ( MemcachedSessionExpiration authId now )
      [ MemcachedSessionExpirationTime =. now
      ]

  insertSession = replaceSession' False
  replaceSession = replaceSession' True
-- | Shared implementation of 'insertSession' and 'replaceSession'. It encrypts
-- the session and writes it to memcached, using @add@ when inserting and
-- @replace@ when replacing.
--
-- When inserting, 'SessionAlreadyExists' is thrown if 'getSession' can still
-- load a session with the same id. memcached failures are mapped as follows:
-- an existing key becomes 'MemcachedSqlStorageKeyCollision', and a missing
-- key becomes 'SessionDoesNotExist'.
replaceSession' :: forall sess.
                   ( Storage (MemcachedSqlStorage sess)
                   , Binary (Decomposed sess)
                   )
                => Bool -- ^ Replace existing?
                -> MemcachedSqlStorage sess
                -> Session (SessionData (MemcachedSqlStorage sess))
                -> SqlPersistT IO ()
replaceSession' isReplace s@MemcachedSqlStorage{..} seNewSession@(review memcachedSqlSession -> (sessId, decoded)) = do
  unless isReplace $ do
    mOld <- getSession @(MemcachedSqlStorage sess) s sessId
    whenIsJust mOld $ \seExistingSession ->
      throwM @_ @(StorageException (MemcachedSqlStorage sess)) $ SessionAlreadyExists{..}
  nonce <- liftIO AEAD.newNonce
  expiry <- liftIO memcachedExpiry
  -- Layout expected by 'getSession': the encoded nonce followed by the AEAD
  -- ciphertext, with the encoded session id as additional authenticated data.
  let encSession = Saltine.encode nonce <> AEAD.aead mcdSqlMemcachedKey nonce encoded encSessId
      encSessId = LBS.toStrict $ Binary.encode sessId
      handleFailure
        = handleIf Memcached.isKeyExists (\_ -> throwM MemcachedSqlStorageKeyCollision)
        . handleIf Memcached.isKeyNotFound (\_ -> throwM @_ @(StorageException (MemcachedSqlStorage sess)) SessionDoesNotExist{..})
  handleFailure . liftIO $
    bool Memcached.add Memcached.replace isReplace zeroBits expiry (memcachedSqlSessionId # sessId) (LBS.fromStrict encSession) mcdSqlMemcached
  where
    encoded = LBS.toStrict $ Binary.encode decoded
    -- memcached treats expiration values above 30 days (2592000 seconds) as
    -- absolute Unix timestamps rather than offsets from now. Longer lifetimes
    -- are therefore sent as an absolute timestamp. 'Nothing' is sent as 0,
    -- i.e. no expiration.
    memcachedExpiry = case mcdSqlMemcachedExpiration of
      Nothing -> return 0
      Just ttl
        | ttl <= 60 * 60 * 24 * 30 -> return $ ceiling ttl
        | otherwise -> ceiling . (+ ttl) <$> getPOSIXTime