From 751f3a463abd76163b870b97031e35af372e2890 Mon Sep 17 00:00:00 2001 From: David Mosbach Date: Tue, 9 Jan 2024 03:57:03 +0100 Subject: [PATCH] generate encrypted tokens --- oauth2-mock-server.cabal | 3 ++ package.yaml | 1 + src/AuthCode.hs | 8 ++++- src/Server.hs | 66 +++++++++++++++++++++++++++++++--------- 4 files changed, 63 insertions(+), 15 deletions(-) diff --git a/oauth2-mock-server.cabal b/oauth2-mock-server.cabal index 25276d8..eeb6d46 100644 --- a/oauth2-mock-server.cabal +++ b/oauth2-mock-server.cabal @@ -35,6 +35,7 @@ library , bytestring , containers , http-client + , jose-jwt , servant , servant-client , servant-server @@ -60,6 +61,7 @@ executable oauth2-mock-server-exe , bytestring , containers , http-client + , jose-jwt , oauth2-mock-server , servant , servant-client @@ -87,6 +89,7 @@ test-suite oauth2-mock-server-test , bytestring , containers , http-client + , jose-jwt , oauth2-mock-server , servant , servant-client diff --git a/package.yaml b/package.yaml index 4b1c903..320c558 100644 --- a/package.yaml +++ b/package.yaml @@ -33,6 +33,7 @@ dependencies: - time - transformers - bytestring +- jose-jwt ghc-options: - -Wall diff --git a/src/AuthCode.hs b/src/AuthCode.hs index 708ba4b..da1dfe1 100644 --- a/src/AuthCode.hs +++ b/src/AuthCode.hs @@ -18,10 +18,16 @@ import Control.Concurrent.STM.TVar import Control.Monad (void, (>=>)) import Control.Monad.STM +import Jose.Jwk (Jwk(..)) -newtype State = State { activeCodes :: Map String (String, UTCTime) } deriving Show -- ^ maps auth codes to (client ID, expiration time) + +data State = State + { activeCodes :: Map String (String, UTCTime) -- ^ maps auth codes to (client ID, expiration time) + , publicKey :: Jwk + , privateKey :: Jwk + } deriving Show type AuthState = TVar State diff --git a/src/Server.hs b/src/Server.hs index 4a8ecf4..408b76a 100644 --- a/src/Server.hs +++ b/src/Server.hs @@ -1,4 +1,4 @@ -{-# LANGUAGE DataKinds, TypeOperators, OverloadedStrings, ScopedTypeVariables, TypeApplications #-} +{-# LANGUAGE DataKinds, TypeOperators, OverloadedStrings, ScopedTypeVariables, TypeApplications, RecordWildCards #-} module Server ( insecureOAuthMock' @@ -11,23 +11,30 @@ import User import Control.Applicative ((<|>)) import Control.Concurrent -import Control.Concurrent.STM.TVar (newTVarIO) +import Control.Concurrent.STM (atomically) +import Control.Concurrent.STM.TVar (newTVarIO, readTVar) import Control.Exception (bracket) import Control.Monad (unless) import Control.Monad.IO.Class import Control.Monad.Trans.Reader import Data.Aeson -import Data.ByteString (ByteString (..)) +import Data.ByteString (ByteString (..), toStrict) import Data.List (find, elemIndex) import Data.Maybe (fromMaybe, isJust) import Data.String (IsString (..)) import Data.Text hiding (elem, find, head, length, map, null, splitAt, tail, words) import Data.Text.Encoding (decodeUtf8) -import Data.Time.Clock (NominalDiffTime (..), nominalDay) +import Data.Time.Clock (NominalDiffTime (..), nominalDay, UTCTime(..), getCurrentTime, addUTCTime) +import qualified Data.ByteString.Char8 as BS import qualified Data.Map.Strict as Map +import Jose.Jwa +import Jose.Jwe +import Jose.Jwk (generateRsaKeyPair, generateSymmetricKey, KeyUse(Enc), KeyId) +import Jose.Jwt hiding (decode, encode) + import Network.HTTP.Client (newManager, defaultManagerSettings) import Network.Wai.Handler.Warp @@ -148,17 +155,23 @@ frontend ba = client authAPI ba "[ID]" "42" "code" "" runMockServer :: Int -> IO () runMockServer port = do - state <- newTVarIO $ State { activeCodes = Map.empty } + state <- mkState run port $ insecureOAuthMock' testUsers state runMockServer' :: Int -> IO () runMockServer' port = do mgr <- newManager defaultManagerSettings - state <- newTVarIO $ State { activeCodes = Map.empty } + state <- mkState bracket (forkIO . run port $ insecureOAuthMock' testUsers state) killThread $ \_ -> runClientM (frontend $ BasicAuthData "foo@bar.com" "0000") (mkClientEnv mgr (BaseUrl Http "localhost" port "")) >>= print +mkState :: IO AuthState +mkState = do + (publicKey, privateKey) <- generateRsaKeyPair 256 (KeyId "Oauth2MockKey") Enc Nothing + let activeCodes = Map.empty + newTVarIO State{..} + ------ ------ Token @@ -200,26 +213,51 @@ instance FromJSON GrantType where | otherwise = error $ show s ++ " grant type not supported yet" data JWT = JWT - { token :: Text -- TODO should be JWT - , tokenType :: Text -- TODO enum - , expiration :: NominalDiffTime - } + { issuer :: Text + , expiration :: UTCTime + } deriving (Show, Eq) + +instance ToJSON JWT where + toJSON (JWT i e) = object ["iss" .= i, "exp" .= e] + +data JWTWrapper = JWTW + { token :: String + , expiresIn :: NominalDiffTime + } deriving (Show) + +instance ToJSON JWTWrapper where + toJSON (JWTW t e) = object ["access_token" .= t, "token_type" .= ("JWT" :: Text), "expires_in" .= e] type Token = "token" :> ReqBody '[JSON] ClientData - :> Post '[JSON] JWT + :> Post '[JSON] JWTWrapper tokenEndpoint :: AuthServer Token tokenEndpoint = provideToken where - provideToken :: ClientData -> AuthHandler JWT + provideToken :: ClientData -> AuthHandler JWTWrapper provideToken client = case (grantType client) of AuthCode -> do - --TODO validate everything unless (Client (pack $ clientID client) (pack $ clientSecret client) `elem` trustedClients) . throwError $ err500 { errBody = "Invalid client" } valid <- asks (verify (grant client) (clientID client)) >>= liftIO unless valid . throwError $ err500 { errBody = "Invalid authorisation code" } - return JWT {token = "", tokenType = "JWT", expiration = 0.25 * nominalDay} + -- return JWT {token = "", tokenType = "JWT", expiration = 0.25 * nominalDay} + token <- asks mkToken >>= liftIO + return token x -> error $ show x ++ " not supported yet" + +mkToken :: AuthState -> IO JWTWrapper +mkToken state = do + privateKey <- atomically $ readTVar state >>= return . privateKey + now <- getCurrentTime + let + lifetime = nominalDay / 4 -- TODO make configurable + jwt = JWT "Oauth2MockServer" (lifetime `addUTCTime` now) + encoded <- jwkEncode RSA_OAEP_256 A128GCM privateKey (Nested . Jwt . toStrict $ encode jwt) + case encoded of + Right (Jwt token) -> return $ JWTW (BS.unpack token) lifetime + Left e -> error $ show e + +