182 lines
7.5 KiB
Haskell
182 lines
7.5 KiB
Haskell
-- SPDX-FileCopyrightText: 2022 Gregor Kleen <gregor.kleen@ifi.lmu.de>
|
|
--
|
|
-- SPDX-License-Identifier: AGPL-3.0-or-later
|
|
|
|
{-# OPTIONS_GHC -fno-warn-orphans #-}
|
|
{-# LANGUAGE UndecidableInstances #-}
|
|
|
|
module Web.ServerSession.Backend.Persistent.Memcached
|
|
( migrateMemcachedSqlStorage
|
|
, MemcachedSessionExpirationId, MemcachedSessionExpiration(..)
|
|
, MemcachedSqlStorage(..)
|
|
, _mcdSqlConnPool, _mcdSqlMemcached, _mcdSqlMemcachedKey, _mcdSqlMemcachedExpiration
|
|
) where
|
|
|
|
import Import.NoModel hiding (AuthId, SessionMap, getSession)
|
|
|
|
import Utils.Lens
|
|
|
|
import Web.ServerSession.Core
|
|
|
|
import qualified Utils.Pool as Custom
|
|
|
|
import qualified Data.Binary as Binary
|
|
|
|
import qualified Database.Memcached.Binary.IO as Memcached
|
|
|
|
import qualified Crypto.Saltine.Class as Saltine
|
|
import qualified Crypto.Saltine.Internal.ByteSizes as Saltine
|
|
import qualified Crypto.Saltine.Core.AEAD as AEAD
|
|
|
|
import qualified Data.ByteString as BS
|
|
import qualified Data.ByteString.Lazy as LBS
|
|
|
|
import qualified Data.ByteString.Base64.URL as Base64
|
|
|
|
import Utils.Metrics (DBConnUseState)
|
|
|
|
|
|
import Data.Text.Encoding (decodeUtf8')
|
|
|
|
import Data.Bits (Bits(zeroBits))
|
|
|
|
|
|
-- SQL schema for per-auth-id invalidation markers.
--
-- Each row stores, for one auth id, the instant written by
-- 'deleteAllSessionsOfAuthId'; 'getSession' only accepts sessions of that
-- auth id that were created strictly after this instant.  The uniqueness
-- constraint keeps at most one marker per auth id.
share
  [ mkPersist sqlSettings
  , mkMigrate "migrateMemcachedSqlStorage"
  ]
  [persistLowerCase|
    MemcachedSessionExpiration
      authId ByteString
      time UTCTime
      UniqueMemcachedSessionExpiration authId
      deriving Show Eq Ord
  |]
|
|
|
|
|
|
-- | Session storage that keeps the encrypted session payload in memcached and
-- per-auth-id invalidation markers ('MemcachedSessionExpiration') in SQL.
data MemcachedSqlStorage sess = MemcachedSqlStorage
  { mcdSqlConnPool            :: forall m. MonadIO m => Custom.Pool' m DBConnLabel DBConnUseState SqlBackend
    -- ^ SQL connection pool that 'runTransactionM' runs storage transactions on.
  , mcdSqlMemcached           :: Memcached.Connection
    -- ^ Memcached connection used to read, write and delete session entries.
  , mcdSqlMemcachedKey        :: AEAD.Key
    -- ^ Key used to seal and open session payloads before they touch memcached.
  , mcdSqlMemcachedExpiration :: Maybe NominalDiffTime
    -- ^ Expiry handed to memcached, rounded up to whole seconds; 'Nothing' is
    -- sent as @0@.
    -- NOTE(review): memcached presumably treats @0@ as "never expire" and
    -- values above 30 days as absolute timestamps -- confirm configured values.
  }

-- Generates the underscore-prefixed lenses listed in the module's export list.
makeLenses_ ''MemcachedSqlStorage
|
|
|
|
-- | Failures specific to this backend, raised while reading or writing the
-- encrypted session entries stored in memcached.
data MemcachedSqlStorageException
  = MemcachedSqlStorageKeyCollision
    -- ^ Memcached reported that the key already exists while storing a session.
  | MemcachedSqlStorageAEADCiphertextTooShort
    -- ^ The stored value is shorter than a nonce plus an authentication tag.
  | MemcachedSqlStorageAEADCouldNotDecodeNonce
    -- ^ The leading bytes of the stored value are not a valid nonce.
  | MemcachedSqlStorageAEADCouldNotOpenAEAD
    -- ^ Authenticated decryption of the stored value failed.
  | MemcachedSqlStorageAEADCouldDecodeMemcachedSqlSession
    -- ^ The decrypted payload could not be decoded, or left unconsumed input.
    -- NOTE(review): the name reads "CouldDecode" although 'getSession' throws
    -- it on decoding /failure/; renaming would have to touch every use site.
  deriving stock (Eq, Ord, Read, Show, Generic)

instance Exception MemcachedSqlStorageException
|
|
|
|
-- | The portion of a 'Session' that is serialised into memcached.
--
-- The session key is deliberately absent: 'memcachedSqlSession' pairs this
-- record with a 'SessionId' to rebuild the full 'Session'.
data MemcachedSqlSession sess = MemcachedSqlSession
  { mcdSqlSessionAuthId     :: Maybe AuthId
  , mcdSqlSessionData       :: Decomposed sess
  , mcdSqlSessionCreatedAt  :: UTCTime
  , mcdSqlSessionAccessedAt :: UTCTime
  }
  deriving stock (Generic)

-- The standard classes are only available when the decomposed session data
-- supports them, hence the standalone derivings with explicit contexts.
deriving stock instance Eq   (Decomposed sess) => Eq   (MemcachedSqlSession sess)
deriving stock instance Ord  (Decomposed sess) => Ord  (MemcachedSqlSession sess)
deriving stock instance Read (Decomposed sess) => Read (MemcachedSqlSession sess)
deriving stock instance Show (Decomposed sess) => Show (MemcachedSqlSession sess)

-- No methods are given, so the class defaults apply (presumably the
-- 'Generic'-based encoding -- confirm against the 'Binary' class in scope).
instance Binary (Decomposed sess) => Binary (MemcachedSqlSession sess)
|
|
-- Orphan instance: a 'SessionId' is written as the first 18 raw bytes obtained
-- by (leniently) URL-safe-base64-decoding its path piece, and read back by
-- base64-encoding 18 bytes and parsing the result as a path piece.
instance Binary (SessionId sess) where
  get = do
    rawBytes <- replicateM 18 Binary.get
    let pathPiece = decodeUtf8 . Base64.encode $ BS.pack rawBytes
    maybe (fail "Could not decode SessionId fromPathPiece") return $ fromPathPiece pathPiece

  put sessId = mapM_ Binary.put rawBytes
    where
      -- Only the first 18 decoded bytes are emitted, matching what 'get' reads.
      rawBytes = take 18 . BS.unpack . Base64.decodeLenient . encodeUtf8 $ toPathPiece sessId
|
|
|
|
-- | Isomorphism between a session key paired with the stored record and the
-- 'Session' type of @serversession@; fields are copied one-to-one in both
-- directions.
memcachedSqlSession :: Iso' (SessionId sess, MemcachedSqlSession sess) (Session sess)
memcachedSqlSession = iso assemble disassemble
  where
    -- Attach the key to the stored fields.
    assemble
      ( key
      , MemcachedSqlSession
          { mcdSqlSessionAuthId     = authId
          , mcdSqlSessionData       = payload
          , mcdSqlSessionCreatedAt  = createdAt
          , mcdSqlSessionAccessedAt = accessedAt
          }
      ) =
        Session
          { sessionKey        = key
          , sessionAuthId     = authId
          , sessionData       = payload
          , sessionCreatedAt  = createdAt
          , sessionAccessedAt = accessedAt
          }

    -- Split the key off again, keeping everything else in the stored record.
    disassemble
      Session
        { sessionKey        = key
        , sessionAuthId     = authId
        , sessionData       = payload
        , sessionCreatedAt  = createdAt
        , sessionAccessedAt = accessedAt
        } =
        ( key
        , MemcachedSqlSession
            { mcdSqlSessionAuthId     = authId
            , mcdSqlSessionData       = payload
            , mcdSqlSessionCreatedAt  = createdAt
            , mcdSqlSessionAccessedAt = accessedAt
            }
        )
|
|
|
|
-- 'SessionMap' reuses the 'Binary' instance of the type it wraps (newtype
-- deriving).  The type is not defined in this module, so this is presumably
-- one of the orphans silenced by @-fno-warn-orphans@.
deriving newtype instance Binary SessionMap
|
|
|
|
|
|
-- | Prism between memcached keys and session ids: a 'SessionId' is rendered as
-- the UTF-8 bytes of its path piece; a key only matches if it is valid UTF-8
-- and parses as a path piece.
memcachedSqlSessionId :: Prism' ByteString (SessionId dat)
memcachedSqlSessionId = prism' toKey fromKey
  where
    toKey sessId = encodeUtf8 (toPathPiece sessId)

    fromKey keyBytes = case decodeUtf8' keyBytes of
      Left _    -> Nothing
      Right txt -> fromPathPiece txt
|
|
|
|
|
|
-- @serversession@ storage backend: session payloads live AEAD-encrypted in
-- memcached (keyed by session id, with the binary-encoded session id as
-- associated data), while "delete all sessions of a user" is implemented via
-- invalidation timestamps kept in SQL.
instance (IsSessionData sess, Binary (Decomposed sess)) => Storage (MemcachedSqlStorage sess) where
  type SessionData (MemcachedSqlStorage sess) = sess
  type TransactionM (MemcachedSqlStorage sess) = SqlPersistT IO

  -- Run a storage transaction on the backend's SQL connection pool.
  runTransactionM MemcachedSqlStorage{..} transaction = customRunSqlPool transaction mcdSqlConnPool

  -- Look a session up.  Inside the 'ExceptT' block, failing with @Nothing@
  -- means "no (valid) session" and yields @Nothing@, while failing with
  -- @Just exc@ rethrows @exc@.
  getSession MemcachedSqlStorage{..} sessId = exceptT (maybe (return Nothing) throwM) (return . Just) $ do
    -- Get-and-touch the stored value; a missing key is not an error.
    sealed <- catchIfExceptT (const Nothing) Memcached.isKeyNotFound . liftIO $
      LBS.toStrict <$> Memcached.getAndTouch_ expiry (memcachedSqlSessionId # sessId) mcdSqlMemcached

    -- The stored value is @nonce <> ciphertext-with-tag@.
    let minimumLength = Saltine.secretBoxNonce + Saltine.secretBoxMac
    guardExceptT (BS.length sealed >= minimumLength) $
      Just MemcachedSqlStorageAEADCiphertextTooShort

    let (nonceBytes, ciphertext) = BS.splitAt Saltine.secretBoxNonce sealed
        -- Binding the session id as associated data ties a ciphertext to its key.
        associatedData = LBS.toStrict $ Binary.encode sessId
    nonce <- maybeTExceptT (Just MemcachedSqlStorageAEADCouldNotDecodeNonce) . hoistMaybe $
      Saltine.decode nonceBytes
    plaintext <- maybeTExceptT (Just MemcachedSqlStorageAEADCouldNotOpenAEAD) . hoistMaybe $
      AEAD.aeadOpen mcdSqlMemcachedKey nonce ciphertext associatedData

    -- Decoding must succeed and consume the entire plaintext.
    let decodeCompletely bytes = case Binary.decodeOrFail bytes of
          Right (leftover, _, value)
            | LBS.null leftover -> Just value
          _ -> Nothing
    decoded <- maybeTExceptT (Just MemcachedSqlStorageAEADCouldDecodeMemcachedSqlSession) . hoistMaybe $
      decodeCompletely (LBS.fromStrict plaintext)

    -- Authenticated sessions may have an invalidation marker in SQL.
    invalidatedAt <- runMaybeT $ do
      authId <- hoistMaybe $ mcdSqlSessionAuthId decoded
      marker <- MaybeT . lift . getBy $ UniqueMemcachedSessionExpiration authId
      return . memcachedSessionExpirationTime $ entityVal marker

    -- Sessions created at or before the marker count as non-existent.
    guardExceptT (maybe True (mcdSqlSessionCreatedAt decoded >) invalidatedAt) Nothing

    return $ (sessId, decoded) ^. memcachedSqlSession
    where
      -- Whole seconds, rounded up; no configured expiration is sent as 0.
      expiry = maybe 0 ceiling mcdSqlMemcachedExpiration

  -- Remove the memcached entry; an already-missing key is silently accepted.
  deleteSession storage sessId =
    liftIO . handleIf Memcached.isKeyNotFound (\_ -> return ()) $
      Memcached.delete (memcachedSqlSessionId # sessId) (mcdSqlMemcached storage)

  -- Nothing is deleted from memcached here: the invalidation marker of the
  -- auth id is set to the current time instead, which makes 'getSession'
  -- reject every session of that auth id created up to now.
  deleteAllSessionsOfAuthId MemcachedSqlStorage{} authId =
    liftIO getCurrentTime >>= \now ->
      void $
        upsert
          (MemcachedSessionExpiration authId now)
          [MemcachedSessionExpirationTime =. now]

  -- Both delegate to replaceSession'; the flag selects replace-vs-insert.
  insertSession = replaceSession' False
  replaceSession = replaceSession' True
|
|
|
|
-- | Shared implementation of 'insertSession' and 'replaceSession': encrypt the
-- session and store it in memcached under its session id.
--
-- When inserting, an existing session visible through 'getSession' raises
-- 'SessionAlreadyExists'; a key that memcached still holds regardless raises
-- 'MemcachedSqlStorageKeyCollision'.  When replacing, a missing key raises
-- 'SessionDoesNotExist'.
--
-- NOTE(review): the expiry is passed to memcached as whole seconds; memcached
-- presumably interprets values above 30 days as absolute timestamps -- confirm
-- that configured expirations stay below that.
replaceSession'
  :: forall sess.
     ( Storage (MemcachedSqlStorage sess)
     , Binary (Decomposed sess)
     )
  => Bool -- ^ Replace existing?
  -> MemcachedSqlStorage sess
  -> Session (SessionData (MemcachedSqlStorage sess))
  -> SqlPersistT IO ()
replaceSession' isReplace storage@MemcachedSqlStorage{..} newSession@(review memcachedSqlSession -> (sessId, decoded)) = do
  -- Inserting must not clobber a session that is still valid.
  unless isReplace $ do
    existing <- getSession @(MemcachedSqlStorage sess) storage sessId
    whenIsJust existing $ \oldSession ->
      throwM @_ @(StorageException (MemcachedSqlStorage sess)) $
        SessionAlreadyExists
          { seExistingSession = oldSession
          , seNewSession      = newSession
          }

  -- A fresh nonce per write; the stored value is @nonce <> ciphertext@, with
  -- the binary-encoded session id authenticated as associated data.
  nonce <- liftIO AEAD.newNonce
  let associatedData = LBS.toStrict $ Binary.encode sessId
      sealed = Saltine.encode nonce <> AEAD.aead mcdSqlMemcachedKey nonce plaintext associatedData
      store = if isReplace then Memcached.replace else Memcached.add

  -- Translate memcached's "key exists" / "key not found" into storage errors.
  handleIf Memcached.isKeyExists (const $ throwM MemcachedSqlStorageKeyCollision)
    . handleIf Memcached.isKeyNotFound
        (const . throwM @_ @(StorageException (MemcachedSqlStorage sess)) $ SessionDoesNotExist{ seNewSession = newSession })
    . liftIO
    $ store zeroBits expiry (memcachedSqlSessionId # sessId) (LBS.fromStrict sealed) mcdSqlMemcached
  where
    plaintext = LBS.toStrict $ Binary.encode decoded
    -- Whole seconds, rounded up; no configured expiration is sent as 0.
    expiry = maybe 0 ceiling mcdSqlMemcachedExpiration
|