diff --git a/serversession-backend-redis/src/Web/ServerSession/Backend/Redis/Internal.hs b/serversession-backend-redis/src/Web/ServerSession/Backend/Redis/Internal.hs index 88f2a35..73c686a 100644 --- a/serversession-backend-redis/src/Web/ServerSession/Backend/Redis/Internal.hs +++ b/serversession-backend-redis/src/Web/ServerSession/Backend/Redis/Internal.hs @@ -18,6 +18,7 @@ module Web.ServerSession.Backend.Redis.Internal , getSessionImpl , deleteSessionImpl , removeSessionFromAuthId + , insertSessionForAuthId , deleteAllSessionsOfAuthIdImpl , insertSessionImpl , replaceSessionImpl @@ -185,9 +186,19 @@ deleteSessionImpl sid = do -- | Remove the given 'SessionId' from the set of sessions of the -- given 'AuthId'. Does not do anything if @Nothing@. removeSessionFromAuthId :: R.RedisCtx m f => SessionId -> Maybe AuthId -> m () -removeSessionFromAuthId _ Nothing = return () -removeSessionFromAuthId sid (Just authId) = - void $ R.srem (rAuthKey authId) [rSessionKey sid] +removeSessionFromAuthId = fooSessionBarAuthId R.srem + +-- | Insert the given 'SessionId' into the set of sessions of the +-- given 'AuthId'. Does not do anything if @Nothing@. +insertSessionForAuthId :: R.RedisCtx m f => SessionId -> Maybe AuthId -> m () +insertSessionForAuthId = fooSessionBarAuthId R.sadd + + +-- | (Internal) Helper for 'removeSessionFromAuthId' and 'insertSessionForAuthId' +fooSessionBarAuthId + :: R.RedisCtx m f => (ByteString -> [ByteString] -> m (f Integer)) -> SessionId -> Maybe AuthId -> m () +fooSessionBarAuthId _ _ Nothing = return () +fooSessionBarAuthId fun sid (Just authId) = void $ fun (rAuthKey authId) [rSessionKey sid] -- | Delete all sessions of the given auth ID. @@ -200,21 +211,40 @@ deleteAllSessionsOfAuthIdImpl authId = do -- | Insert a new session. insertSessionImpl :: Session -> R.Redis () insertSessionImpl session = do - transaction $ do - let sk = rSessionKey $ sessionKey session - r <- R.hmset sk (printSession session) - -- TODO: R.expireat - maybe (return ()) (\authId -> void $ R.sadd (rAuthKey authId) [sk]) $ sessionAuthId session - return (() <$ r) + -- Check that no old session exists. + let sid = sessionKey session + moldSession <- getSessionImpl sid + case moldSession of + Just oldSession -> liftIO $ E.throwIO $ SessionAlreadyExists oldSession session + Nothing -> do + transaction $ do + let sk = rSessionKey sid + r <- R.hmset sk (printSession session) + -- TODO: R.expireat + insertSessionForAuthId (sessionKey session) (sessionAuthId session) + return (() <$ r) -- | Replace the contents of a session. replaceSessionImpl :: Session -> R.Redis () replaceSessionImpl session = do - -- Remove the old auth ID from the map if it has changed. - oldSession <- getSessionImpl (sessionKey session) - let oldAuthId = sessionAuthId =<< oldSession - when (oldAuthId /= sessionAuthId session) $ - removeSessionFromAuthId (sessionKey session) oldAuthId - -- Otherwise the operation is the same as inserting. - insertSessionImpl session + -- Check that the old session exists. + let sid = sessionKey session + moldSession <- getSessionImpl sid + case moldSession of + Nothing -> liftIO $ E.throwIO $ SessionDoesNotExist session + Just oldSession -> do + transaction $ do + -- Delete the old session and set the new one. + let sk = rSessionKey sid + _ <- R.del [sk] + r <- R.hmset sk (printSession session) + + -- Remove the old auth ID from the map if it has changed. + let oldAuthId = sessionAuthId oldSession + newAuthId = sessionAuthId session + when (oldAuthId /= newAuthId) $ do + removeSessionFromAuthId sid oldAuthId + insertSessionForAuthId sid newAuthId + + return (() <$ r) diff --git a/serversession-backend-redis/tests/Main.hs b/serversession-backend-redis/tests/Main.hs index 4574d43..ffe74c5 100644 --- a/serversession-backend-redis/tests/Main.hs +++ b/serversession-backend-redis/tests/Main.hs @@ -5,8 +5,6 @@ import Test.Hspec import Web.ServerSession.Backend.Redis import Web.ServerSession.Core.StorageTests -import qualified Control.Exception as E - main :: IO () main = do conn <- connect defaultConnectInfo