Use HashMap for SessionMap.

It's a bit faster and uses a bit less memory.
This commit is contained in:
Felipe Lessa 2015-05-31 23:09:02 -03:00
parent e127371df6
commit 0cca9cd086
17 changed files with 92 additions and 81 deletions

View File

@ -21,6 +21,7 @@ library
, containers , containers
, mtl , mtl
, safecopy == 0.8.* , safecopy == 0.8.*
, unordered-containers
, serversession == 1.0.* , serversession == 1.0.*
exposed-modules: exposed-modules:
@ -40,7 +41,7 @@ test-suite tests
type: exitcode-stdio-1.0 type: exitcode-stdio-1.0
hs-source-dirs: tests hs-source-dirs: tests
build-depends: build-depends:
base, acid-state, containers, mtl, safecopy base, acid-state, containers, mtl, safecopy, unordered-containers
, hspec >= 2.1 && < 3 , hspec >= 2.1 && < 3

View File

@ -33,6 +33,7 @@ import Data.SafeCopy
import Data.Typeable (Typeable) import Data.Typeable (Typeable)
import qualified Control.Exception as E import qualified Control.Exception as E
import qualified Data.HashMap.Strict as HM
import qualified Data.Map.Strict as M import qualified Data.Map.Strict as M
import qualified Data.Set as S import qualified Data.Set as S
import qualified Web.ServerSession.Core as SS import qualified Web.ServerSession.Core as SS
@ -100,7 +101,11 @@ nothingfy s = if S.null s then Nothing else Just s
---------------------------------------------------------------------- ----------------------------------------------------------------------
deriveSafeCopy 0 'base ''SS.SessionMap -- | We can't @deriveSafeCopy 0 'base ''SS.SessionMap@ because
-- @safeCopy@ doesn't contain instances for @HashMap@ as of now.
instance SafeCopy SS.SessionMap where
putCopy = contain . safePut . HM.toList . SS.unSessionMap
getCopy = contain $ SS.SessionMap . HM.fromList <$> safeGet
-- | We can't @deriveSafeCopy 0 'base ''SS.SessionId@ as -- | We can't @deriveSafeCopy 0 'base ''SS.SessionId@ as

View File

@ -21,13 +21,13 @@ library
, base64-bytestring == 1.0.* , base64-bytestring == 1.0.*
, bytestring , bytestring
, cereal >= 0.4 , cereal >= 0.4
, containers
, path-pieces , path-pieces
, persistent == 2.1.* , persistent == 2.1.*
, tagged >= 0.8 , tagged >= 0.8
, text , text
, time , time
, transformers , transformers
, unordered-containers
, serversession == 1.0.* , serversession == 1.0.*
exposed-modules: exposed-modules:
@ -58,9 +58,9 @@ test-suite tests
hs-source-dirs: tests hs-source-dirs: tests
build-depends: build-depends:
base, aeson, base64-bytestring, bytestring, cereal, containers, base, aeson, base64-bytestring, bytestring, cereal,
path-pieces, persistent, persistent-template, text, time, path-pieces, persistent, persistent-template, text, time,
transformers transformers, unordered-containers
, hspec >= 2.1 && < 3 , hspec >= 2.1 && < 3
, monad-logger , monad-logger

View File

@ -25,7 +25,6 @@ import Web.ServerSession.Core
import qualified Control.Exception as E import qualified Control.Exception as E
import qualified Data.Aeson as A import qualified Data.Aeson as A
import qualified Data.Map as M
import qualified Data.Text as T import qualified Data.Text as T
import qualified Database.Persist as P import qualified Database.Persist as P
import qualified Database.Persist.Sql as P import qualified Database.Persist.Sql as P
@ -95,7 +94,7 @@ instance forall sess. P.PersistFieldSql (Decomposed sess) => P.PersistEntity (Pe
[] []
[] []
["Eq", "Ord", "Show", "Typeable"] ["Eq", "Ord", "Show", "Typeable"]
M.empty mempty
False False
where where
pfd :: P.EntityField (PersistentSession sess) typ -> P.FieldDef pfd :: P.EntityField (PersistentSession sess) typ -> P.FieldDef

View File

@ -25,7 +25,7 @@ import Web.ServerSession.Core.Internal (SessionId(..))
import qualified Data.Aeson as A import qualified Data.Aeson as A
import qualified Data.ByteString.Base64.URL as B64URL import qualified Data.ByteString.Base64.URL as B64URL
import qualified Data.Map as M import qualified Data.HashMap.Strict as HM
import qualified Data.Serialize as S import qualified Data.Serialize as S
import qualified Data.Text as T import qualified Data.Text as T
import qualified Data.Text.Encoding as TE import qualified Data.Text.Encoding as TE
@ -102,17 +102,17 @@ instance PersistFieldSql SessionMap where
sqlType _ = SqlBlob sqlType _ = SqlBlob
instance S.Serialize SessionMap where instance S.Serialize SessionMap where
put = S.put . map (first TE.encodeUtf8) . M.toAscList . unSessionMap put = S.put . map (first TE.encodeUtf8) . HM.toList . unSessionMap
get = SessionMap . M.fromAscList . map (first TE.decodeUtf8) <$> S.get get = SessionMap . HM.fromList . map (first TE.decodeUtf8) <$> S.get
instance A.FromJSON SessionMap where instance A.FromJSON SessionMap where
parseJSON = fmap fixup . A.parseJSON parseJSON = fmap fixup . A.parseJSON
where where
fixup :: M.Map Text ByteStringJ -> SessionMap fixup :: HM.HashMap Text ByteStringJ -> SessionMap
fixup = SessionMap . fmap unB fixup = SessionMap . fmap unB
instance A.ToJSON SessionMap where instance A.ToJSON SessionMap where
toJSON = A.toJSON . mangle toJSON = A.toJSON . mangle
where where
mangle :: SessionMap -> M.Map Text ByteStringJ mangle :: SessionMap -> HM.HashMap Text ByteStringJ
mangle = fmap B . unSessionMap mangle = fmap B . unSessionMap

View File

@ -18,13 +18,13 @@ library
build-depends: build-depends:
base == 4.* base == 4.*
, bytestring , bytestring
, containers
, hedis == 0.6.* , hedis == 0.6.*
, path-pieces , path-pieces
, tagged >= 0.8 , tagged >= 0.8
, text , text
, time >= 1.5 , time >= 1.5
, transformers , transformers
, unordered-containers
, serversession == 1.0.* , serversession == 1.0.*
exposed-modules: exposed-modules:
@ -44,8 +44,8 @@ test-suite tests
type: exitcode-stdio-1.0 type: exitcode-stdio-1.0
hs-source-dirs: tests hs-source-dirs: tests
build-depends: build-depends:
base, bytestring, containers, hedis, path-pieces, text, base, bytestring, hedis, path-pieces, text,
time, transformers time, transformers, unordered-containers
, hspec >= 2.1 && < 3 , hspec >= 2.1 && < 3

View File

@ -42,7 +42,7 @@ import qualified Control.Exception as E
import qualified Database.Redis as R import qualified Database.Redis as R
import qualified Data.ByteString as B import qualified Data.ByteString as B
import qualified Data.ByteString.Char8 as B8 import qualified Data.ByteString.Char8 as B8
import qualified Data.Map.Strict as M import qualified Data.HashMap.Strict as HM
import qualified Data.Text.Encoding as TE import qualified Data.Text.Encoding as TE
import qualified Data.Time.Clock as TI import qualified Data.Time.Clock as TI
import qualified Data.Time.Format as TI import qualified Data.Time.Format as TI
@ -141,8 +141,8 @@ class IsSessionData sess => RedisSession sess where
-- | Assumes that keys are UTF-8 encoded when parsing (which is -- | Assumes that keys are UTF-8 encoded when parsing (which is
-- true if keys are always generated via @toHash@). -- true if keys are always generated via @toHash@).
instance RedisSession SessionMap where instance RedisSession SessionMap where
toHash _ = map (first TE.encodeUtf8) . M.toList . unSessionMap toHash _ = map (first TE.encodeUtf8) . HM.toList . unSessionMap
fromHash _ = SessionMap . M.fromList . map (first TE.decodeUtf8) fromHash _ = SessionMap . HM.fromList . map (first TE.decodeUtf8)
-- | Parse a 'Session' from a Redis hash. -- | Parse a 'Session' from a Redis hash.

View File

@ -18,7 +18,6 @@ library
build-depends: build-depends:
base == 4.* base == 4.*
, bytestring , bytestring
, containers
, nonce , nonce
, path-pieces , path-pieces
, snap == 0.14.* , snap == 0.14.*
@ -26,6 +25,7 @@ library
, text , text
, time , time
, transformers , transformers
, unordered-containers
, serversession == 1.0.* , serversession == 1.0.*
exposed-modules: exposed-modules:

View File

@ -24,10 +24,10 @@ import Web.ServerSession.Core
import qualified Crypto.Nonce as N import qualified Crypto.Nonce as N
import qualified Data.ByteString.Char8 as B8 import qualified Data.ByteString.Char8 as B8
import qualified Data.HashMap.Strict as HM
import qualified Data.Text as T import qualified Data.Text as T
import qualified Data.Text.Encoding as TE import qualified Data.Text.Encoding as TE
import qualified Data.Time as TI import qualified Data.Time as TI
import qualified Data.Map as M
import qualified Snap.Core as S import qualified Snap.Core as S
import qualified Snap.Snaplet as S import qualified Snap.Snaplet as S
import qualified Snap.Snaplet.Session as S import qualified Snap.Snaplet.Session as S
@ -84,27 +84,27 @@ class IsSessionData sess => SnapSession sess where
-- | Uses 'csrfKey'. -- | Uses 'csrfKey'.
instance SnapSession SessionMap where instance SnapSession SessionMap where
ssInsert key val = onSM (M.insert key (TE.encodeUtf8 val)) ssInsert key val = onSM (HM.insert key (TE.encodeUtf8 val))
ssLookup key = fmap TE.decodeUtf8 . M.lookup key . unSessionMap ssLookup key = fmap TE.decodeUtf8 . HM.lookup key . unSessionMap
ssDelete key = onSM (M.delete key) ssDelete key = onSM (HM.delete key)
ssToList = ssToList =
-- Remove the CSRF key from the list as the current -- Remove the CSRF key from the list as the current
-- clientsession backend doesn't return it. -- clientsession backend doesn't return it.
fmap (second TE.decodeUtf8) . fmap (second TE.decodeUtf8) .
M.toList . HM.toList .
M.delete csrfKey . HM.delete csrfKey .
unSessionMap unSessionMap
ssInsertCsrf = ssInsert csrfKey ssInsertCsrf = ssInsert csrfKey
ssLookupCsrf = ssLookup csrfKey ssLookupCsrf = ssLookup csrfKey
ssForceInvalidate force = onSM (M.insert forceInvalidateKey (B8.pack $ show force)) ssForceInvalidate force = onSM (HM.insert forceInvalidateKey (B8.pack $ show force))
-- | Apply a function to a 'SessionMap'. -- | Apply a function to a 'SessionMap'.
onSM onSM
:: (M.Map Text ByteString -> M.Map Text ByteString) :: (HM.HashMap Text ByteString -> HM.HashMap Text ByteString)
-> (SessionMap -> SessionMap) -> (SessionMap -> SessionMap)
onSM f = SessionMap . f . unSessionMap onSM f = SessionMap . f . unSessionMap

View File

@ -18,13 +18,13 @@ library
build-depends: build-depends:
base >= 4.6 && < 5 base >= 4.6 && < 5
, bytestring , bytestring
, containers
, cookie >= 0.4 , cookie >= 0.4
, data-default , data-default
, path-pieces , path-pieces
, text , text
, time , time
, transformers , transformers
, unordered-containers
, vault , vault
, wai , wai
, wai-session == 0.3.* , wai-session == 0.3.*

View File

@ -21,8 +21,8 @@ import Web.ServerSession.Core
import Web.ServerSession.Core.Internal (absoluteTimeout, idleTimeout, persistentCookies) import Web.ServerSession.Core.Internal (absoluteTimeout, idleTimeout, persistentCookies)
import qualified Data.ByteString.Char8 as B8 import qualified Data.ByteString.Char8 as B8
import qualified Data.HashMap.Strict as HM
import qualified Data.IORef as I import qualified Data.IORef as I
import qualified Data.Map as M
import qualified Data.Text.Encoding as TE import qualified Data.Text.Encoding as TE
import qualified Data.Time as TI import qualified Data.Time as TI
import qualified Data.Vault.Lazy as V import qualified Data.Vault.Lazy as V
@ -100,8 +100,8 @@ class IsSessionData sess => KeyValue sess where
instance KeyValue SessionMap where instance KeyValue SessionMap where
type Key SessionMap = Text type Key SessionMap = Text
type Value SessionMap = ByteString type Value SessionMap = ByteString
kvLookup k = M.lookup k . unSessionMap kvLookup k = HM.lookup k . unSessionMap
kvInsert k v (SessionMap m) = SessionMap (M.insert k v m) kvInsert k v (SessionMap m) = SessionMap (HM.insert k v m)
---------------------------------------------------------------------- ----------------------------------------------------------------------

View File

@ -25,6 +25,7 @@ library
, text , text
, time , time
, transformers , transformers
, unordered-containers
, wai , wai
, yesod-core == 1.4.* , yesod-core == 1.4.*

View File

@ -21,6 +21,7 @@ import Yesod.Core.Handler (setSessionBS)
import Yesod.Core.Types (Header(AddCookie), SessionBackend(..)) import Yesod.Core.Types (Header(AddCookie), SessionBackend(..))
import qualified Data.ByteString.Char8 as B8 import qualified Data.ByteString.Char8 as B8
import qualified Data.HashMap.Strict as HM
import qualified Data.Map as M import qualified Data.Map as M
import qualified Data.Text.Encoding as TE import qualified Data.Text.Encoding as TE
import qualified Data.Time as TI import qualified Data.Time as TI
@ -107,8 +108,8 @@ class IsSessionMap sess where
instance IsSessionMap SessionMap where instance IsSessionMap SessionMap where
toSessionMap = unSessionMap toSessionMap = M.fromList . HM.toList . unSessionMap
fromSessionMap = SessionMap fromSessionMap = SessionMap . HM.fromList . M.toList
---------------------------------------------------------------------- ----------------------------------------------------------------------

View File

@ -20,13 +20,14 @@ library
, aeson , aeson
, base64-bytestring == 1.0.* , base64-bytestring == 1.0.*
, bytestring , bytestring
, containers
, data-default , data-default
, hashable
, nonce == 1.0.* , nonce == 1.0.*
, path-pieces , path-pieces
, text , text
, time , time
, transformers , transformers
, unordered-containers
exposed-modules: exposed-modules:
Web.ServerSession.Core Web.ServerSession.Core
Web.ServerSession.Core.Internal Web.ServerSession.Core.Internal
@ -47,9 +48,11 @@ test-suite tests
type: exitcode-stdio-1.0 type: exitcode-stdio-1.0
hs-source-dirs: tests hs-source-dirs: tests
build-depends: build-depends:
base, aeson, base64-bytestring, bytestring, containers, base, aeson, base64-bytestring, bytestring, data-default,
data-default, nonce, path-pieces, text, time, transformers nonce, path-pieces, text, time, transformers,
unordered-containers
, containers
, hspec >= 2.1 && < 3 , hspec >= 2.1 && < 3
, QuickCheck , QuickCheck
, serversession , serversession

View File

@ -49,6 +49,7 @@ import Control.Applicative ((<$>), (<*>))
import Control.Monad (guard, when) import Control.Monad (guard, when)
import Control.Monad.IO.Class (MonadIO(..)) import Control.Monad.IO.Class (MonadIO(..))
import Data.ByteString (ByteString) import Data.ByteString (ByteString)
import Data.Hashable (Hashable(..))
import Data.Maybe (catMaybes, fromMaybe, isJust) import Data.Maybe (catMaybes, fromMaybe, isJust)
import Data.Text (Text) import Data.Text (Text)
import Data.Time (UTCTime, getCurrentTime) import Data.Time (UTCTime, getCurrentTime)
@ -61,7 +62,7 @@ import qualified Crypto.Nonce as N
import qualified Data.Aeson as A import qualified Data.Aeson as A
import qualified Data.ByteString.Base64.URL as B64URL import qualified Data.ByteString.Base64.URL as B64URL
import qualified Data.ByteString.Char8 as B8 import qualified Data.ByteString.Char8 as B8
import qualified Data.Map as M import qualified Data.HashMap.Strict as HM
import qualified Data.Text as T import qualified Data.Text as T
import qualified Data.Text.Encoding as TE import qualified Data.Text.Encoding as TE
@ -93,6 +94,9 @@ instance A.FromJSON (SessionId sess) where
instance A.ToJSON (SessionId sess) where instance A.ToJSON (SessionId sess) where
toJSON = A.toJSON . unS toJSON = A.toJSON . unS
instance Hashable (SessionId sess) where
hashWithSalt s = hashWithSalt s . unS
-- | (Internal) Check that the given text is a base64url-encoded -- | (Internal) Check that the given text is a base64url-encoded
-- representation of 18 bytes. -- representation of 18 bytes.
@ -150,8 +154,8 @@ deriving instance Show (Decomposed sess) => Show (Session sess)
-- to support this session data type on all frontends and storage -- to support this session data type on all frontends and storage
-- backends. -- backends.
newtype SessionMap = newtype SessionMap =
SessionMap { unSessionMap :: M.Map Text ByteString } SessionMap { unSessionMap :: HM.HashMap Text ByteString }
deriving (Eq, Ord, Show, Read, Typeable) deriving (Eq, Show, Read, Typeable)
---------------------------------------------------------------------- ----------------------------------------------------------------------
@ -224,22 +228,25 @@ class ( Show (Decomposed sess)
instance IsSessionData SessionMap where instance IsSessionData SessionMap where
type Decomposed SessionMap = SessionMap type Decomposed SessionMap = SessionMap
emptySession = SessionMap M.empty emptySession = SessionMap HM.empty
isSameDecomposed _ = (==) isSameDecomposed _ = (==)
decomposeSession authKey_ (SessionMap sm1) = decomposeSession authKey_ (SessionMap sm1) =
let (authId, sm2) = M.updateLookupWithKey (\_ _ -> Nothing) authKey_ sm1 let authId = HM.lookup authKey_ sm1
(force, sm3) = M.updateLookupWithKey (\_ _ -> Nothing) forceInvalidateKey sm2 force = maybe DoNotForceInvalidate (read . B8.unpack) $
HM.lookup forceInvalidateKey sm1
sm2 = HM.delete authKey_ $
HM.delete forceInvalidateKey sm1
in DecomposedSession in DecomposedSession
{ dsAuthId = authId { dsAuthId = authId
, dsForceInvalidate = maybe DoNotForceInvalidate (read . B8.unpack) force , dsForceInvalidate = force
, dsDecomposed = SessionMap sm3 } , dsDecomposed = SessionMap sm2 }
recomposeSession authKey_ mauthId (SessionMap sm) = recomposeSession authKey_ mauthId (SessionMap sm) =
SessionMap $ maybe id (M.insert authKey_) mauthId sm SessionMap $ maybe id (HM.insert authKey_) mauthId sm
isDecomposedEmpty _ = M.null . unSessionMap isDecomposedEmpty _ = HM.null . unSessionMap
-- | A session data type @sess@ with its special variables taken apart. -- | A session data type @sess@ with its special variables taken apart.

View File

@ -14,7 +14,7 @@ import Web.ServerSession.Core.Internal
import qualified Crypto.Nonce as N import qualified Crypto.Nonce as N
import qualified Data.ByteString as B import qualified Data.ByteString as B
import qualified Data.Map as M import qualified Data.HashMap.Strict as HM
import qualified Data.Text as T import qualified Data.Text as T
import qualified Data.Time as TI import qualified Data.Time as TI
@ -172,11 +172,11 @@ allStorageTests storage it runIO parallel _shouldBe shouldReturn shouldThrow = d
let session = Session let session = Session
{ sessionKey = sid { sessionKey = sid
, sessionAuthId = Nothing , sessionAuthId = Nothing
, sessionData = SessionMap $ M.fromList vals , sessionData = SessionMap $ HM.fromList vals
, sessionCreatedAt = now , sessionCreatedAt = now
, sessionAccessedAt = now , sessionAccessedAt = now
} }
ver2 = session { sessionData = SessionMap M.empty } ver2 = session { sessionData = SessionMap HM.empty }
run (getSession storage sid) `shouldReturn` Nothing run (getSession storage sid) `shouldReturn` Nothing
run (insertSession storage session) run (insertSession storage session)
run (getSession storage sid) `shouldReturn` (Just session) run (getSession storage sid) `shouldReturn` (Just session)
@ -217,7 +217,7 @@ generateSession gen hasAuthId = do
data_ <- do data_ <- do
keys <- replicateM 20 (N.nonce128urlT gen) keys <- replicateM 20 (N.nonce128urlT gen)
vals <- replicateM 20 (N.nonce128url gen) vals <- replicateM 20 (N.nonce128url gen)
return $ M.fromList (zip keys vals) return $ HM.fromList (zip keys vals)
now <- TI.getCurrentTime now <- TI.getCurrentTime
return Session return Session
{ sessionKey = sid { sessionKey = sid

View File

@ -14,8 +14,8 @@ import Web.ServerSession.Core.StorageTests
import qualified Control.Exception as E import qualified Control.Exception as E
import qualified Crypto.Nonce as N import qualified Crypto.Nonce as N
import qualified Data.ByteString.Char8 as B8 import qualified Data.ByteString.Char8 as B8
import qualified Data.HashMap.Strict as HM
import qualified Data.IORef as I import qualified Data.IORef as I
import qualified Data.Map as M
import qualified Data.Set as S import qualified Data.Set as S
import qualified Data.Text as T import qualified Data.Text as T
import qualified Data.Time as TI import qualified Data.Time as TI
@ -128,7 +128,7 @@ main = hspec $ parallel $ do
st <- createState =<< prepareMockStorage [session] st <- createState =<< prepareMockStorage [session]
(retSessionMap, SaveSessionToken msession _now) <- (retSessionMap, SaveSessionToken msession _now) <-
loadSession st (Just $ B8.pack $ T.unpack $ unS $ sessionKey session) loadSession st (Just $ B8.pack $ T.unpack $ unS $ sessionKey session)
retSessionMap `shouldBe` onSM (M.insert (authKey st) authId) (sessionData session) retSessionMap `shouldBe` onSM (HM.insert (authKey st) authId) (sessionData session)
msession `shouldBe` Just session msession `shouldBe` Just session
describe "checkExpired" $ do describe "checkExpired" $ do
@ -224,22 +224,22 @@ main = hspec $ parallel $ do
sessionData session1 `shouldBe` m1 sessionData session1 `shouldBe` m1
getMockOperations sto `shouldReturn` [InsertSession session1] getMockOperations sto `shouldReturn` [InsertSession session1]
let m2 = onSM (M.insert (authKey st) "john") m1 let m2 = onSM (HM.insert (authKey st) "john") m1
Just session2 <- saveSession st (SaveSessionToken (Just session1) fakenow) m2 Just session2 <- saveSession st (SaveSessionToken (Just session1) fakenow) m2
sessionAuthId session2 `shouldBe` Just "john" sessionAuthId session2 `shouldBe` Just "john"
sessionData session2 `shouldBe` m1 sessionData session2 `shouldBe` m1
sessionKey session2 == sessionKey session1 `shouldBe` False sessionKey session2 == sessionKey session1 `shouldBe` False
getMockOperations sto `shouldReturn` [DeleteSession (sessionKey session1), InsertSession session2] getMockOperations sto `shouldReturn` [DeleteSession (sessionKey session1), InsertSession session2]
let m3 = onSM (M.insert forceInvalidateKey (B8.pack $ show AllSessionIdsOfLoggedUser)) m2 let m3 = onSM (HM.insert forceInvalidateKey (B8.pack $ show AllSessionIdsOfLoggedUser)) m2
Just session3 <- saveSession st (SaveSessionToken (Just session2) fakenow) m3 Just session3 <- saveSession st (SaveSessionToken (Just session2) fakenow) m3
session3 `shouldBe` session2 { sessionKey = sessionKey session3 } session3 `shouldBe` session2 { sessionKey = sessionKey session3 }
getMockOperations sto `shouldReturn` getMockOperations sto `shouldReturn`
[DeleteSession (sessionKey session2), DeleteAllSessionsOfAuthId "john", InsertSession session3] [DeleteSession (sessionKey session2), DeleteAllSessionsOfAuthId "john", InsertSession session3]
let m4 = onSM (M.insert "x" "y") m2 let m4 = onSM (HM.insert "x" "y") m2
Just session4 <- saveSession st (SaveSessionToken (Just session3) fakenow) m4 Just session4 <- saveSession st (SaveSessionToken (Just session3) fakenow) m4
session4 `shouldBe` session3 { sessionData = onSM (M.delete (authKey st)) m4 } session4 `shouldBe` session3 { sessionData = onSM (HM.delete (authKey st)) m4 }
getMockOperations sto `shouldReturn` [ReplaceSession session4] getMockOperations sto `shouldReturn` [ReplaceSession session4]
Just session5 <- saveSession st (SaveSessionToken (Just session4) (TI.addUTCTime 10 fakenow)) m4 Just session5 <- saveSession st (SaveSessionToken (Just session4) (TI.addUTCTime 10 fakenow)) m4
@ -369,13 +369,13 @@ main = hspec $ parallel $ do
prop "parses the force invalidate key" $ prop "parses the force invalidate key" $
\data_ -> \data_ ->
let sessionMap v = onSM (M.insert forceInvalidateKey (B8.pack $ show v)) $ mkSessionMap data_ let sessionMap v = onSM (HM.insert forceInvalidateKey (B8.pack $ show v)) $ mkSessionMap data_
allForces = [minBound..maxBound] :: [ForceInvalidate] allForces = [minBound..maxBound] :: [ForceInvalidate]
test v = dsForceInvalidate (decomposeSession authKey_ $ sessionMap v) Q.=== v test v = dsForceInvalidate (decomposeSession authKey_ $ sessionMap v) Q.=== v
in Q.conjoin (test <$> allForces) in Q.conjoin (test <$> allForces)
it "removes the auth key" $ do it "removes the auth key" $ do
let m = M.singleton "a" "b"; m' = M.insert (authKey stnull) "x" m let m = HM.singleton "a" "b"; m' = HM.insert (authKey stnull) "x" m
decomposeSession authKey_ (SessionMap m') `shouldBe` decomposeSession authKey_ (SessionMap m') `shouldBe`
DecomposedSession (Just "x") DoNotForceInvalidate (SessionMap m) DecomposedSession (Just "x") DoNotForceInvalidate (SessionMap m)
@ -392,7 +392,7 @@ main = hspec $ parallel $ do
let s = mkSessionMap ((T.unpack authKey_, "foo") : data_) let s = mkSessionMap ((T.unpack authKey_, "foo") : data_)
authId = B8.pack authId_ authId = B8.pack authId_
in recomposeSession authKey_ (Just authId) s in recomposeSession authKey_ (Just authId) s
Q.=== onSM (M.adjust (const authId) authKey_) s Q.=== onSM (HM.adjust (const authId) authKey_) s
describe "MockStorage" $ do describe "MockStorage" $ do
sto <- runIO emptyMockStorage sto <- runIO emptyMockStorage
@ -401,13 +401,13 @@ main = hspec $ parallel $ do
-- | Used to generate session maps on QuickCheck properties. -- | Used to generate session maps on QuickCheck properties.
mkSessionMap :: [(String, String)] -> SessionMap mkSessionMap :: [(String, String)] -> SessionMap
mkSessionMap = SessionMap . M.fromList . map (T.pack *** B8.pack) mkSessionMap = SessionMap . HM.fromList . map (T.pack *** B8.pack)
-- | Apply a function to a 'SessionMap'. -- | Apply a function to a 'SessionMap'.
onSM onSM
:: (M.Map T.Text B8.ByteString -> M.Map T.Text B8.ByteString) :: (HM.HashMap T.Text B8.ByteString -> HM.HashMap T.Text B8.ByteString)
-> (SessionMap -> SessionMap) -> (SessionMap -> SessionMap)
onSM f = SessionMap . f . unSessionMap onSM f = SessionMap . f . unSessionMap
@ -484,7 +484,7 @@ deriving instance Show (Decomposed sess) => Show (MockOperation sess)
-- | A mock storage used just for testing. -- | A mock storage used just for testing.
data MockStorage sess = data MockStorage sess =
MockStorage MockStorage
{ mockSessions :: I.IORef (M.Map (SessionId sess) (Session sess)) { mockSessions :: I.IORef (HM.HashMap (SessionId sess) (Session sess))
, mockOperations :: I.IORef [MockOperation sess] , mockOperations :: I.IORef [MockOperation sess]
} }
deriving (Typeable) deriving (Typeable)
@ -498,30 +498,24 @@ instance IsSessionData sess => Storage (MockStorage sess) where
-- because latter may be reordered (cf. "Memory Model" on -- because latter may be reordered (cf. "Memory Model" on
-- Data.IORef's documentation). -- Data.IORef's documentation).
addMockOperation sto (GetSession sid) addMockOperation sto (GetSession sid)
M.lookup sid <$> I.atomicModifyIORef' (mockSessions sto) (\a -> (a, a)) HM.lookup sid <$> I.atomicModifyIORef' (mockSessions sto) (\a -> (a, a))
deleteSession sto sid = do deleteSession sto sid = do
I.atomicModifyIORef' (mockSessions sto) ((, ()) . M.delete sid) I.atomicModifyIORef' (mockSessions sto) ((, ()) . HM.delete sid)
addMockOperation sto (DeleteSession sid) addMockOperation sto (DeleteSession sid)
deleteAllSessionsOfAuthId sto authId = do deleteAllSessionsOfAuthId sto authId = do
I.atomicModifyIORef' (mockSessions sto) ((, ()) . M.filter (\s -> sessionAuthId s /= Just authId)) I.atomicModifyIORef' (mockSessions sto) ((, ()) . HM.filter (\s -> sessionAuthId s /= Just authId))
addMockOperation sto (DeleteAllSessionsOfAuthId authId) addMockOperation sto (DeleteAllSessionsOfAuthId authId)
insertSession sto session = do insertSession sto session = do
join $ I.atomicModifyIORef' (mockSessions sto) $ \oldMap -> join $ I.atomicModifyIORef' (mockSessions sto) $ \oldMap ->
let (moldVal, newMap) = case HM.lookup (sessionKey session) oldMap of
M.insertLookupWithKey (\_ v _ -> v) (sessionKey session) session oldMap Just oldVal -> (oldMap, mockThrow $ SessionAlreadyExists oldVal session)
in maybe Nothing -> (HM.insert (sessionKey session) session oldMap, return ())
(newMap, return ())
(\oldVal -> (oldMap, mockThrow $ SessionAlreadyExists oldVal session))
moldVal
addMockOperation sto (InsertSession session) addMockOperation sto (InsertSession session)
replaceSession sto session = do replaceSession sto session = do
join $ I.atomicModifyIORef' (mockSessions sto) $ \oldMap -> join $ I.atomicModifyIORef' (mockSessions sto) $ \oldMap ->
let (moldVal, newMap) = case HM.lookup (sessionKey session) oldMap of
M.insertLookupWithKey (\_ v _ -> v) (sessionKey session) session oldMap Just _ -> (HM.insert (sessionKey session) session oldMap, return ())
in maybe Nothing -> (oldMap, mockThrow $ SessionDoesNotExist session)
(oldMap, mockThrow $ SessionDoesNotExist session)
(const (newMap, return ()))
moldVal
addMockOperation sto (ReplaceSession session) addMockOperation sto (ReplaceSession session)
@ -537,7 +531,7 @@ mockThrow = E.throwIO
emptyMockStorage :: IO (MockStorage sess) emptyMockStorage :: IO (MockStorage sess)
emptyMockStorage = emptyMockStorage =
MockStorage MockStorage
<$> I.newIORef M.empty <$> I.newIORef HM.empty
<*> I.newIORef [] <*> I.newIORef []
@ -545,7 +539,7 @@ emptyMockStorage =
prepareMockStorage :: [Session sess] -> IO (MockStorage sess) prepareMockStorage :: [Session sess] -> IO (MockStorage sess)
prepareMockStorage sessions = do prepareMockStorage sessions = do
sto <- emptyMockStorage sto <- emptyMockStorage
I.writeIORef (mockSessions sto) (M.fromList [(sessionKey s, s) | s <- sessions]) I.writeIORef (mockSessions sto) (HM.fromList [(sessionKey s, s) | s <- sessions])
return sto return sto