-- SPDX-FileCopyrightText: 2022 Gregor Kleen
--
-- SPDX-License-Identifier: AGPL-3.0-or-later

{-# OPTIONS_GHC -fno-warn-orphans #-}
{-# LANGUAGE UndecidableInstances #-}

-- | @serversession@ storage backend that keeps AEAD-encrypted session
-- payloads in memcached and uses SQL (via persistent) only to record
-- per-auth-id invalidation timestamps.
module Web.ServerSession.Backend.Persistent.Memcached
  ( migrateMemcachedSqlStorage
  , MemcachedSessionExpirationId, MemcachedSessionExpiration(..)
  , MemcachedSqlStorage(..)
  , _mcdSqlConnPool, _mcdSqlMemcached, _mcdSqlMemcachedKey, _mcdSqlMemcachedExpiration
  ) where

import Import.NoModel hiding (AuthId, SessionMap, getSession)
import Utils.Lens

import Web.ServerSession.Core

import qualified Utils.Pool as Custom

import qualified Data.Binary as Binary

import qualified Database.Memcached.Binary.IO as Memcached

import qualified Crypto.Saltine.Class as Saltine
import qualified Crypto.Saltine.Internal.ByteSizes as Saltine
import qualified Crypto.Saltine.Core.AEAD as AEAD

import qualified Data.ByteString as BS
import qualified Data.ByteString.Lazy as LBS
import qualified Data.ByteString.Base64.URL as Base64

import Utils.Metrics (DBConnUseState)

import Data.Text.Encoding (decodeUtf8')
import Data.Bits (Bits(zeroBits))
import Data.Time.Clock.POSIX (getPOSIXTime)


-- SQL side of the storage: for each auth id, the time at which all of its
-- sessions were invalidated by 'deleteAllSessionsOfAuthId'.  'getSession'
-- rejects sessions whose creation time is not strictly after this time.
share [mkPersist sqlSettings, mkMigrate "migrateMemcachedSqlStorage"] [persistLowerCase|
MemcachedSessionExpiration
  authId ByteString
  time UTCTime
  UniqueMemcachedSessionExpiration authId
  deriving Show Eq Ord
|]


-- | Configuration of the memcached/SQL session storage backend.
data MemcachedSqlStorage sess = MemcachedSqlStorage
  { mcdSqlConnPool :: forall m. MonadIO m => Custom.Pool' m DBConnLabel DBConnUseState SqlBackend
    -- ^ Database pool used to run 'TransactionM' (the invalidation table)
  , mcdSqlMemcached :: Memcached.Connection
    -- ^ Memcached connection holding the encrypted session payloads
  , mcdSqlMemcachedKey :: AEAD.Key
    -- ^ Key used to encrypt and authenticate the payloads in memcached
  , mcdSqlMemcachedExpiration :: Maybe NominalDiffTime
    -- ^ Lifetime of a memcached entry; it is re-applied on every read
    -- (@getAndTouch@).  'Nothing' is sent to memcached as expiry @0@.
  }
-- Generates the @_mcdSql…@ lenses exported above.
makeLenses_ ''MemcachedSqlStorage


-- | Failures raised by this backend (in addition to serversession's own
-- 'StorageException').
data MemcachedSqlStorageException
  = MemcachedSqlStorageKeyCollision
    -- ^ @add@ found an existing memcached entry for a fresh session id
  | MemcachedSqlStorageAEADCiphertextTooShort
    -- ^ Stored value is shorter than nonce + authentication tag
  | MemcachedSqlStorageAEADCouldNotDecodeNonce
    -- ^ Nonce prefix of the stored value could not be decoded
  | MemcachedSqlStorageAEADCouldNotOpenAEAD
    -- ^ Decryption/authentication of the stored value failed
  | MemcachedSqlStorageAEADCouldNotDecodeMemcachedSqlSession
    -- ^ Decrypted plaintext is not a (fully consumed) 'MemcachedSqlSession'
  deriving (Eq, Ord, Read, Show, Generic)
instance Exception MemcachedSqlStorageException


-- | Plaintext stored (encrypted) in memcached: a 'Session' minus its key,
-- since the key is the memcached key itself.
data MemcachedSqlSession sess = MemcachedSqlSession
  { mcdSqlSessionAuthId :: Maybe AuthId
  , mcdSqlSessionData :: Decomposed sess
  , mcdSqlSessionCreatedAt, mcdSqlSessionAccessedAt :: UTCTime
  } deriving (Generic)

deriving instance Eq (Decomposed sess) => Eq (MemcachedSqlSession sess)
deriving instance Ord (Decomposed sess) => Ord (MemcachedSqlSession sess)
deriving instance Read (Decomposed sess) => Read (MemcachedSqlSession sess)
deriving instance Show (Decomposed sess) => Show (MemcachedSqlSession sess)

instance Binary (Decomposed sess) => Binary (MemcachedSqlSession sess)

-- Serialise a 'SessionId' as the raw bytes behind its base64url text form:
-- exactly 18 bytes are written and read back.
-- NOTE(review): assumes session ids are 24 base64url characters (18 bytes);
-- longer ids would be silently truncated by 'put' — confirm against
-- serversession's id generator.
instance Binary (SessionId sess) where
  get = maybe (fail "Could not decode SessionId fromPathPiece") return . fromPathPiece . decodeUtf8 . Base64.encode . BS.pack =<< replicateM 18 Binary.get
  put = mapM_ Binary.put . take 18 . BS.unpack . Base64.decodeLenient . encodeUtf8 . toPathPiece


-- | Split a 'Session' into its key and the payload that is stored under
-- that key (and reassemble it).
memcachedSqlSession :: Iso' (SessionId sess, MemcachedSqlSession sess) (Session sess)
memcachedSqlSession = iso toSession fromSession
  where
    toSession (mcdSqlSessionKey, MemcachedSqlSession{..}) = Session
      { sessionKey = mcdSqlSessionKey
      , sessionAuthId = mcdSqlSessionAuthId
      , sessionData = mcdSqlSessionData
      , sessionCreatedAt = mcdSqlSessionCreatedAt
      , sessionAccessedAt = mcdSqlSessionAccessedAt
      }
    fromSession Session{..} =
      ( sessionKey
      , MemcachedSqlSession
        { mcdSqlSessionAuthId = sessionAuthId
        , mcdSqlSessionData = sessionData
        , mcdSqlSessionCreatedAt = sessionCreatedAt
        , mcdSqlSessionAccessedAt = sessionAccessedAt
        }
      )

-- Orphan instance: reuse the Binary instance of SessionMap's underlying map.
deriving newtype instance Binary SessionMap

-- | Memcached key of a session: the UTF-8 encoding of its path piece.
memcachedSqlSessionId :: Prism' ByteString (SessionId dat)
memcachedSqlSessionId = prism' (encodeUtf8 . toPathPiece) (fromPathPiece <=< either (const Nothing) Just . decodeUtf8')


-- | Largest expiration memcached interprets as an offset from "now" (30 days).
-- Larger values are taken by the server to be absolute Unix timestamps.
memcachedMaxRelativeExpiry :: NominalDiffTime
memcachedMaxRelativeExpiry = 30 * 24 * 60 * 60

-- | Translate the configured entry lifetime into a memcached expiration value.
--
-- * 'Nothing' becomes @0@ (as before).
-- * Lifetimes of at most 30 days are sent as relative seconds (as before).
-- * Longer lifetimes are sent as an absolute Unix timestamp (now + lifetime);
--   sending them as relative seconds would be misread by memcached as a
--   timestamp in 1970, expiring the entry immediately.
memcachedExpiry :: Integral i => Maybe NominalDiffTime -> IO i
memcachedExpiry Nothing = return 0
memcachedExpiry (Just ttl)
  | ttl <= memcachedMaxRelativeExpiry = return $ ceiling ttl
  | otherwise = ceiling . (+ ttl) <$> getPOSIXTime


instance (IsSessionData sess, Binary (Decomposed sess)) => Storage (MemcachedSqlStorage sess) where
  type SessionData (MemcachedSqlStorage sess) = sess
  type TransactionM (MemcachedSqlStorage sess) = SqlPersistT IO

  runTransactionM MemcachedSqlStorage{..} act = customRunSqlPool act mcdSqlConnPool

  -- Load, decrypt and validate a session.
  --
  -- Error channel of the inner ExceptT: 'Nothing' means "no valid session"
  -- (missing in memcached, or invalidated via the SQL table) and is turned
  -- into a 'Nothing' result; 'Just' carries a 'MemcachedSqlStorageException'
  -- which is rethrown.
  getSession MemcachedSqlStorage{..} sessId = exceptT (maybe (return Nothing) throwM) (return . Just) $ do
    -- Fetch and, at the same time, refresh the entry's expiration.
    encSession <- catchIfExceptT (const Nothing) Memcached.isKeyNotFound . liftIO $ do
      expiry <- memcachedExpiry mcdSqlMemcachedExpiration
      LBS.toStrict <$> Memcached.getAndTouch_ expiry (memcachedSqlSessionId # sessId) mcdSqlMemcached
    -- Stored layout is @nonce <> ciphertext@ (see replaceSession').
    -- NOTE(review): these are secretbox size constants although the payload is
    -- produced by the AEAD module; presumably the sizes coincide — confirm.
    guardExceptT (BS.length encSession >= Saltine.secretBoxNonce + Saltine.secretBoxMac) $
      Just MemcachedSqlStorageAEADCiphertextTooShort
    let (nonceBS, encrypted) = BS.splitAt Saltine.secretBoxNonce encSession
        -- The serialised session id is the AEAD additional data, binding the
        -- ciphertext to the key it is stored under.
        encSessId = LBS.toStrict $ Binary.encode sessId
    nonce <- maybeTExceptT (Just MemcachedSqlStorageAEADCouldNotDecodeNonce) . hoistMaybe $ Saltine.decode nonceBS
    decrypted <- maybeTExceptT (Just MemcachedSqlStorageAEADCouldNotOpenAEAD) . hoistMaybe $ AEAD.aeadOpen mcdSqlMemcachedKey nonce encrypted encSessId
    -- Decode, rejecting trailing garbage.
    let binaryDecode bs = do
          Right (unconsumed, _, res) <- return $ Binary.decodeOrFail bs
          guard $ LBS.null unconsumed
          return res
    decoded@MemcachedSqlSession{..} <- maybeTExceptT (Just MemcachedSqlStorageAEADCouldNotDecodeMemcachedSqlSession) . hoistMaybe . binaryDecode $ LBS.fromStrict decrypted
    -- Reject the session if its auth id was invalidated at or after creation.
    expiration <- runMaybeT $ fmap (memcachedSessionExpirationTime . entityVal) . MaybeT . lift . getBy . UniqueMemcachedSessionExpiration =<< hoistMaybe mcdSqlSessionAuthId
    guardExceptT (maybe True (mcdSqlSessionCreatedAt >) expiration) Nothing
    return $ (sessId, decoded) ^. memcachedSqlSession

  -- Deleting a session that is already gone is not an error.
  deleteSession MemcachedSqlStorage{..} sessId
    = liftIO . handleIf Memcached.isKeyNotFound (const $ return ()) $ Memcached.delete (memcachedSqlSessionId # sessId) mcdSqlMemcached

  -- Memcached entries are not enumerated; instead the invalidation time is
  -- recorded in SQL and enforced lazily by getSession.
  deleteAllSessionsOfAuthId MemcachedSqlStorage{} authId = do
    now <- liftIO getCurrentTime
    void $ upsert
      ( MemcachedSessionExpiration authId now )
      [ MemcachedSessionExpirationTime =. now ]

  insertSession = replaceSession' False
  replaceSession = replaceSession' True

-- | Encrypt a session and write it to memcached.
--
-- With @isReplace = False@ the session must not exist yet
-- ('SessionAlreadyExists' is thrown if 'getSession' finds it; a concurrent
-- insert surfaces as 'MemcachedSqlStorageKeyCollision').  With
-- @isReplace = True@ the entry must already exist ('SessionDoesNotExist').
replaceSession' :: forall sess.
                   ( Storage (MemcachedSqlStorage sess)
                   , Binary (Decomposed sess)
                   )
                => Bool -- ^ Replace existing?
                -> MemcachedSqlStorage sess
                -> Session (SessionData (MemcachedSqlStorage sess))
                -> SqlPersistT IO ()
replaceSession' isReplace s@MemcachedSqlStorage{..} seNewSession@(review memcachedSqlSession -> (sessId, decoded)) = do
  unless isReplace $ do
    mOld <- getSession @(MemcachedSqlStorage sess) s sessId
    whenIsJust mOld $ \seExistingSession ->
      throwM @_ @(StorageException (MemcachedSqlStorage sess)) $ SessionAlreadyExists{..}

  -- Stored value: fresh nonce followed by the AEAD ciphertext, with the
  -- serialised session id as additional data (checked again by getSession).
  nonce <- liftIO AEAD.newNonce
  let encSession = Saltine.encode nonce <> AEAD.aead mcdSqlMemcachedKey nonce encoded encSessId
      encSessId = LBS.toStrict $ Binary.encode sessId
      handleFailure
        = handleIf Memcached.isKeyExists (\_ -> throwM MemcachedSqlStorageKeyCollision)
        . handleIf Memcached.isKeyNotFound (\_ -> throwM @_ @(StorageException (MemcachedSqlStorage sess)) SessionDoesNotExist{..})
  handleFailure . liftIO $ do
    expiry <- memcachedExpiry mcdSqlMemcachedExpiration
    bool Memcached.add Memcached.replace isReplace zeroBits expiry (memcachedSqlSessionId # sessId) (LBS.fromStrict encSession) mcdSqlMemcached
  where
    encoded = LBS.toStrict $ Binary.encode decoded