diff --git a/oauth2-mock-server.cabal b/oauth2-mock-server.cabal index 4d754df..83ad5af 100644 --- a/oauth2-mock-server.cabal +++ b/oauth2-mock-server.cabal @@ -19,6 +19,7 @@ extra-source-files: library exposed-modules: + AuthCode Server User other-modules: @@ -36,7 +37,10 @@ library , servant , servant-client , servant-server + , stm , text + , time + , transformers , warp default-language: Haskell2010 @@ -58,7 +62,10 @@ executable oauth2-mock-server-exe , servant , servant-client , servant-server + , stm , text + , time + , transformers , warp default-language: Haskell2010 @@ -81,6 +88,9 @@ test-suite oauth2-mock-server-test , servant , servant-client , servant-server + , stm , text + , time + , transformers , warp default-language: Haskell2010 diff --git a/package.yaml b/package.yaml index 94904fb..47736b4 100644 --- a/package.yaml +++ b/package.yaml @@ -29,6 +29,9 @@ dependencies: - aeson - text - containers +- stm +- time +- transformers ghc-options: - -Wall diff --git a/src/AuthCode.hs b/src/AuthCode.hs new file mode 100644 index 0000000..cf373a6 --- /dev/null +++ b/src/AuthCode.hs @@ -0,0 +1,50 @@ +{-# LANGUAGE OverloadedRecordDot #-} + +module AuthCode +( State (..) +, AuthState +, genUnencryptedCode +) where + +import Data.Map.Strict (Map) +import Data.Maybe (isJust) +import Data.Time.Clock + +import qualified Data.Map.Strict as M + +import Control.Concurrent (forkIO, threadDelay) +import Control.Concurrent.STM.TVar +import Control.Monad (void) +import Control.Monad.STM + + + + +newtype State = State { activeCodes :: Map String (String, UTCTime) } deriving Show -- ^ maps auth codes to (client ID, expiration time) + +type AuthState = TVar State + +genUnencryptedCode :: String + -> String + -> NominalDiffTime + -> AuthState + -> IO (Maybe String) +genUnencryptedCode client url expiration state = do + now <- getCurrentTime + let + expiresAt = expiration `addUTCTime` now + simpleCode = client <> url <> show now <> show expiresAt + success <- atomically . stateTVar state $ \s -> + let mEntry = M.lookup simpleCode s.activeCodes + in + if isJust mEntry + then (False, s) + else (True, s{ activeCodes = M.insert simpleCode (client, expiresAt) s.activeCodes }) + if success then expire simpleCode expiration state >> return (Just simpleCode) else return Nothing + + +expire :: String -> NominalDiffTime -> AuthState -> IO () +expire code time state = void . forkIO $ do + threadDelay $ fromEnum time + atomically . modifyTVar state $ \s -> s{ activeCodes = M.delete code s.activeCodes } + diff --git a/src/Server.hs b/src/Server.hs index 4cb312f..0a477a9 100644 --- a/src/Server.hs +++ b/src/Server.hs @@ -6,12 +6,15 @@ module Server , runMockServer' ) where +import AuthCode import User import Control.Concurrent +import Control.Concurrent.STM.TVar (newTVarIO) import Control.Exception (bracket) import Control.Monad (unless) import Control.Monad.IO.Class +import Control.Monad.Trans.Reader import Data.Aeson import Data.List (find) @@ -68,7 +71,14 @@ type Auth user userData = BasicAuth "login" user type Token = "token" :> Post '[JSON] Text -- TODO post jwt token -- type Insert = "insert" :> Post '[JSON] User -authServer :: forall user userData . UserData user userData => Server (Auth user userData) + +type AuthHandler = ReaderT AuthState Handler +type AuthServer a = ServerT a AuthHandler + +toHandler :: AuthState -> AuthHandler a -> Handler a +toHandler s h = runReaderT h s + +authServer :: forall user userData . UserData user userData => AuthServer (Auth user userData) authServer = handleAuth where handleAuth :: user @@ -76,19 +86,20 @@ authServer = handleAuth -> QClient -> QResType -> QRedirect - -> Handler userData + -> AuthHandler userData handleAuth u scopes client responseType url = do - unless (pack client `elem` trustedClients) . -- TODO fetch trusted clients from db + unless (pack client `elem` trustedClients) . -- TODO fetch trusted clients from db | TODO also check if the redirect url really belongs to the client throwError $ err404 { errBody = "Not a trusted client."} let scopes' = map (readScope @user @userData) $ words scopes uData = mconcat $ map (userScope @user @userData u) scopes' responseType' = read @ResponseType responseType - - liftIO (putStrLn $ "user: " ++ show u ++ " | scopes: " ++ show (map (showScope @user @userData) scopes')) + mAuthCode <- asks (genUnencryptedCode client url 600) >>= liftIO + liftIO $ print mAuthCode + liftIO . putStrLn $ "user: " ++ show u ++ " | scopes: " ++ show (map (showScope @user @userData) scopes') return uData -exampleAuthServer :: Server (Auth User (Map.Map Text Text)) +exampleAuthServer :: AuthServer (Auth User (Map.Map Text Text)) exampleAuthServer = authServer authAPI :: Proxy (Auth User (Map.Map Text Text)) @@ -97,9 +108,11 @@ authAPI = Proxy -- insecureOAuthMock :: Application -- insecureOAuthMock = authAPI `serve` exampleAuthServer -insecureOAuthMock' :: [User] -> Application -insecureOAuthMock' testUsers = serveWithContext authAPI c exampleAuthServer - where c = authenticate testUsers :. EmptyContext +insecureOAuthMock' :: [User] -> AuthState -> Application +insecureOAuthMock' testUsers s = serveWithContext authAPI c $ hoistServerWithContext authAPI p (toHandler s) exampleAuthServer + where + c = authenticate testUsers :. EmptyContext + p = Proxy :: Proxy '[BasicAuthCheck User] authenticate :: [User] -> BasicAuthCheck User authenticate users = BasicAuthCheck $ \authData -> do @@ -113,12 +126,15 @@ frontend :: BasicAuthData -> ClientM (Map.Map Text Text) frontend ba = client authAPI ba "[ID]" "42" "code" "" runMockServer :: Int -> IO () -runMockServer port = run port $ insecureOAuthMock' testUsers +runMockServer port = do + state <- newTVarIO $ State { activeCodes = Map.empty } + run port $ insecureOAuthMock' testUsers state runMockServer' :: Int -> IO () runMockServer' port = do mgr <- newManager defaultManagerSettings - bracket (forkIO . run port $ insecureOAuthMock' testUsers) killThread $ \_ -> + state <- newTVarIO $ State { activeCodes = Map.empty } + bracket (forkIO . run port $ insecureOAuthMock' testUsers state) killThread $ \_ -> runClientM (frontend $ BasicAuthData "foo@bar.com" "0000") (mkClientEnv mgr (BaseUrl Http "localhost" port "")) >>= print