{-# LANGUAGE OverloadedRecordDot #-}

-- | In-memory store of short-lived, single-use authorization codes.
--
-- Codes live in a 'TVar'-held map. Each code is removed either when it is
-- presented to 'verify' or by a background thread once its lifetime elapses.
module AuthCode
  ( State (..)
  , AuthState
  , genUnencryptedCode
  , verify
  ) where

import Data.Map.Strict (Map)
import Data.Maybe (isJust)
import Data.Time.Clock
import qualified Data.Map.Strict as M

import Control.Concurrent (forkIO, threadDelay)
import Control.Concurrent.STM.TVar
import Control.Monad (void, (>=>))
import Control.Monad.STM

-- | The set of authorization codes that are currently outstanding.
newtype State = State
  { activeCodes :: Map String (String, UTCTime)
    -- ^ maps auth codes to (client ID, expiration time)
  }
  deriving Show

-- | Shared, transactional handle on the code store.
type AuthState = TVar State

-- | Create and register a new code for @client@.
--
-- The code is the concatenation of the client ID, the URL, the current time
-- and the expiration time, with all spaces removed. It is stored together
-- with the client ID and the expiration time (@now + expiration@), and a
-- background thread is forked to remove it once @expiration@ has elapsed.
--
-- Returns @Just code@ on success, or @Nothing@ when an identical code is
-- already registered (in which case the store is left untouched).
genUnencryptedCode :: String -> String -> NominalDiffTime -> AuthState -> IO (Maybe String)
genUnencryptedCode client url expiration state = do
  now <- getCurrentTime
  let expiresAt = expiration `addUTCTime` now
      simpleCode = filter (/= ' ') $ client <> url <> show now <> show expiresAt
  -- Lookup and insert happen in one transaction so two concurrent callers
  -- can never both register the same code.
  inserted <- atomically . stateTVar state $ \s ->
    if isJust (M.lookup simpleCode s.activeCodes)
      then (False, s)
      else (True, s{ activeCodes = M.insert simpleCode (client, expiresAt) s.activeCodes })
  if inserted
    then expire simpleCode expiration state >> return (Just simpleCode)
    else return Nothing

-- | Fork a thread that removes @code@ from the store after @time@ has passed.
--
-- 'threadDelay' takes microseconds, whereas 'fromEnum' on a
-- 'NominalDiffTime' yields picoseconds, so the duration is converted
-- explicitly (rounding up, so the entry is never removed early).
expire :: String -> NominalDiffTime -> AuthState -> IO ()
expire code time state = void . forkIO $ do
  sleepMicros (ceiling (time * 1000000))
  atomically . modifyTVar' state $ \s ->
    s{ activeCodes = M.delete code s.activeCodes }

-- | Sleep for the given number of microseconds.
--
-- The wait is split into chunks of at most @maxBound :: Int@ so that very
-- long durations cannot overflow the 'Int' argument of 'threadDelay'.
-- Non-positive durations return immediately.
sleepMicros :: Integer -> IO ()
sleepMicros remaining
  | remaining <= 0 = pure ()
  | otherwise = do
      let step = min remaining (toInteger (maxBound :: Int))
      threadDelay (fromInteger step)
      sleepMicros (remaining - step)

-- | Check that @code@ was issued to @clientID@ and has not yet expired.
--
-- The code is removed from the store as part of the check, whatever the
-- outcome, so each code can be presented at most once. Returns 'True' only
-- when the code was present, the stored client ID equals @clientID@, and the
-- current time is strictly before the stored expiration time.
verify :: String -> String -> AuthState -> IO Bool
verify code clientID state = do
  now <- getCurrentTime
  mEntry <- atomically $ do
    s <- readTVar state
    modifyTVar' state $ \st -> st{ activeCodes = M.delete code st.activeCodes }
    return $ M.lookup code s.activeCodes
  return $ case mEntry of
    -- The expiry check must not rely on the cleanup thread having run.
    Just (owner, expiresAt) -> owner == clientID && now < expiresAt
    Nothing -> False