Merge branch 'refresh-tokens' into 'main'

Refresh Tokens

See merge request mosbach/oauth2-mock-server!3
This commit is contained in:
Nora Mosbach 2024-01-29 00:53:51 +00:00
commit 31f99eef37
10 changed files with 245 additions and 98 deletions

3
.gitignore vendored
View File

@ -1,3 +1,4 @@
.stack-work/ .stack-work/
*~ *~
database/* database/*
result

View File

@ -18,7 +18,7 @@
MultiParamTypeClasses, MultiParamTypeClasses,
RecordWildCards #-} RecordWildCards #-}
module UniWorX (User(..), initDB, testUsers) where module UniWorX (User(..), initDB) where
import User import User
@ -33,6 +33,7 @@ import Conduit (ResourceT)
import Data.Map (Map(..)) import Data.Map (Map(..))
import Data.String (IsString(..)) import Data.String (IsString(..))
import Data.Text (Text(..)) import Data.Text (Text(..))
import Data.Yaml (decodeFileThrow, FromJSON(..), Value(..), (.:), (.:?))
import qualified Data.Map as M import qualified Data.Map as M
import qualified Data.Text as T import qualified Data.Text as T
@ -45,18 +46,49 @@ import System.Environment (lookupEnv)
share [mkPersist sqlSettings, mkMigrate "migrateAll"] [persistLowerCase| share [mkPersist sqlSettings, mkMigrate "migrateAll"] [persistLowerCase|
User User
name Text firstName Text
email Text surname Text
email Text
matricNumber Text Maybe
title Text Maybe
sex Text Maybe
birthday Text Maybe
telephone Text Maybe
mobile Text Maybe
compPersNumber Text Maybe
compDepartment Text Maybe
postAddress Text Maybe
deriving Eq Show deriving Eq Show
|] |]
testUsers :: [User] -- TODO move to db instance FromJSON User where
testUsers = parseJSON (Object o) = User
[ User "Fallback User" "foo@bar.com" <$> o .: "userFirstName"
, User "Tina Tester" "tester@campus.lmu.de" <*> o .: "userSurname"
, User "Max Muster" "m@m.mm" ] <*> o .: "userEmail"
<*> o .:? "userMatrikelnummer"
<*> o .:? "userTitle"
<*> o .:? "userSex"
<*> o .:? "userBirthday"
<*> o .:? "userTelephone"
<*> o .:? "userMobile"
<*> o .:? "userCompanyPersonalNumber"
<*> o .:? "userCompanyDepartment"
<*> o .:? "userPostAddress"
parseJSON _ = error "Oauth2 Mock Server: invalid test user format"
runDB :: ReaderT SqlBackend (NoLoggingT (ResourceT IO)) a -> IO a data TestUserSpec = TestUsers
{ specialUsers :: [Map Text User]
, randomUsers :: Map Text [Maybe Text]
} deriving (Show)
instance FromJSON TestUserSpec where
parseJSON (Object o) = TestUsers <$> o .: "special-users" <*> o .: "random-users"
parseJSON _ = error "Oauth2 Mock Server: invalid test user format"
type DB = ReaderT SqlBackend (NoLoggingT (ResourceT IO))
runDB :: DB a -> IO a
runDB action = do runDB action = do
Just port <- lookupEnv "OAUTH2_DB_PORT" -- >>= \p -> return $ p <|> Just "9444" Just port <- lookupEnv "OAUTH2_DB_PORT" -- >>= \p -> return $ p <|> Just "9444"
Just host <- lookupEnv "OAUTH2_PGHOST" Just host <- lookupEnv "OAUTH2_PGHOST"
@ -64,19 +96,39 @@ runDB action = do
runStderrLoggingT $ withPostgresqlPool connStr 10 $ \pool -> liftIO $ flip runSqlPersistMPool pool action runStderrLoggingT $ withPostgresqlPool connStr 10 $ \pool -> liftIO $ flip runSqlPersistMPool pool action
initDB :: IO () initDB :: IO ()
initDB = runDB $ do initDB = do
runMigration migrateAll Just testUserFile <- lookupEnv "OAUTH2_TEST_USERS"
forM_ testUsers $ void . insert runDB $ do
runMigration migrateAll
testUsers <- decodeFileThrow @DB @TestUserSpec testUserFile
liftIO . putStrLn $ "the test users:\n" ++ show testUsers
let users = M.elems . mconcat $ specialUsers testUsers
forM_ users $ void . insert
instance UserData (Entity User) (Map Text Text) where instance UserData (Entity User) (Map Text Text) where
data Scope (Entity User) = ID | Profile deriving (Show, Read, Eq) data Scope (Entity User) = ID | Profile deriving (Show, Read, Eq)
readScope = read readScope = read
showScope = show showScope = show
userScope (Entity _ User{..}) ID = M.singleton "id" userEmail userScope (Entity _ User{..}) ID = M.singleton "id" userEmail
userScope (Entity _ User{..}) Profile = M.fromList [("name", userName), ("email", userEmail)] userScope (Entity _ User{..}) Profile = M.fromList [(key, val) | (key, Just val) <-
[ ("firstName", Just userFirstName)
, ("surname", Just userSurname)
, ("email", Just userEmail)
, ("matriculationNumber", userMatricNumber)
, ("title", userTitle)
, ("sex", userSex)
, ("birthday", userBirthday)
, ("telephone", userTelephone)
, ("mobile", userMobile)
, ("companyPersonalNumber", userCompPersNumber)
, ("companyDepartment", userCompDepartment)
, ("postAddress", userPostAddress)
]]
lookupUser email _ = runDB $ do lookupUser email _ = runDB $ do
user <- selectList [UserEmail ==. email] [] user <- selectList [UserEmail ==. email] []
case user of case user of
[entity] -> return $ Just entity [entity] -> return $ Just entity
[] -> return Nothing [] -> return Nothing
_ -> error "Ambiguous User." _ -> error "Oauth2 Mock Server: Ambiguous User."

View File

@ -29,16 +29,56 @@
with haskell.packages."ghc927"; [ ghc haskell-language-server ] with haskell.packages."ghc927"; [ ghc haskell-language-server ]
) )
); );
libPath = pkgs.lib.makeLibraryPath buildInputs;
oms = pkgs.stdenv.mkDerivation {
inherit buildInputs;
inherit name;
pname = name;
src = ./.;
# dontUnpack = true;
buildPhase = ''
HOME=$out
LD_LIBRARY_PATH=${libPath}
mkdir -p $HOME/.stack
stack build --verbose
'';
installPhase = ''
mkdir -p $out/bin
mv .stack-work/install/${system}/*/*/bin/${name}-exe $out/bin/${name}
echo "moved"
'';
};
mkDB = builtins.readFile ./mkDB.sh;
killDB = builtins.readFile ./killDB.sh;
in { in {
packages.${system}.${name} = nixpkgs.legacyPackages.${system}.${name}; packages.${system} = {
packages.${system}.default = self.packages.${system}.${name}; ${name} = oms; # nixpkgs.legacyPackages.${system}.${name};
mkOauth2DB = pkgs.writeScriptBin "mkOauth2DB" ''
#!${pkgs.zsh}/bin/zsh -e
${mkDB}
'';
killOauth2DB = pkgs.writeScriptBin "killOauth2DB" ''
#!${pkgs.zsh}/bin/zsh -e
${killDB}
'';
default = self.packages.${system}.${name};
};
devShells.${system}.default = pkgs.mkShell { devShells.${system}.default = pkgs.mkShell {
buildInputs = buildInputs; buildInputs = buildInputs ++ (with self.packages.${system}; [mkOauth2DB killOauth2DB]);
LD_LIBRARY_PATH = pkgs.lib.makeLibraryPath buildInputs; LD_LIBRARY_PATH = libPath;
shellHook = builtins.readFile ./mkDB.sh; OAUTH2_HBA = ./hba_file;
OAUTH2_DB_SCHEMA = ./schema.sql;
OAUTH2_TEST_USERS = ./users.yaml;
OAUTH2_SERVER_PORT = 9443;
OAUTH2_DB_PORT = 9444;
shellHook = ''
${mkDB}
zsh
${killDB}
'';
};
}; };
}; }
}

8
killDB.sh Executable file
View File

@ -0,0 +1,8 @@
# SPDX-FileCopyrightText: 2024 UniWorX Systems
# SPDX-FileContributor: David Mosbach <david.mosbach@uniworx.de>
#
# SPDX-License-Identifier: AGPL-3.0-or-later
pg_ctl stop -D "${OAUTH2_PGDIR}"
rm -rvf "${OAUTH2_PGDIR}" "${OAUTH2_PGHOST}" "${OAUTH2_PGLOG}"

14
mkDB.sh
View File

@ -3,10 +3,9 @@
# #
# SPDX-License-Identifier: AGPL-3.0-or-later # SPDX-License-Identifier: AGPL-3.0-or-later
export OAUTH2_SERVER_PORT=9443 [[ -z "${OAUTH2_HBA}" || -z "${OAUTH2_DB_SCHEMA}" ]] && echo "oauth2: missing env vars for hba and/or schema" && exit 1
export OAUTH2_DB_PORT=9444
tmpdir=./database tmpdir=${XDG_RUNTIME_DIR}/.oauth2-db
if [ ! -d "${tmpdir}" ]; then if [ ! -d "${tmpdir}" ]; then
mkdir ${tmpdir} mkdir ${tmpdir}
@ -19,15 +18,12 @@ pgSockDir=$(mktemp -d --tmpdir="${absdir}" postgresql.sock.XXXXXX)
pgLogFile=$(mktemp --tmpdir="${absdir}" postgresql.XXXXXX.log) pgLogFile=$(mktemp --tmpdir="${absdir}" postgresql.XXXXXX.log)
initdb --no-locale -D "${pgDir}" initdb --no-locale -D "${pgDir}"
pg_ctl start -D "${pgDir}" -l "${pgLogFile}" -w -o "-k ${pgSockDir} -c listen_addresses='::' -c hba_file='hba_file' -p ${OAUTH2_DB_PORT} -h localhost -c unix_socket_permissions=0700 -c max_connections=10 -c session_preload_libraries=auto_explain -c auto_explain.log_min_duration=100ms" pg_ctl start -D "${pgDir}" -l "${pgLogFile}" -w -o "-k ${pgSockDir} -c listen_addresses='::' -c hba_file='${OAUTH2_HBA}' -p ${OAUTH2_DB_PORT} -h localhost -c unix_socket_permissions=0700 -c max_connections=10 -c session_preload_libraries=auto_explain -c auto_explain.log_min_duration=100ms"
psql -h "${pgSockDir}" -p ${OAUTH2_DB_PORT} -f ./schema.sql postgres psql -h "${pgSockDir}" -p ${OAUTH2_DB_PORT} -f "${OAUTH2_DB_SCHEMA}" postgres
printf "Postgres logfile is %s\nPostgres socket directory is %s\n" "${pgLogFile}" "${pgSockDir}" printf "Postgres logfile is %s\nPostgres socket directory is %s\n" "${pgLogFile}" "${pgSockDir}"
export OAUTH2_PGHOST="${pgSockDir}" export OAUTH2_PGHOST="${pgSockDir}"
export OAUTH2_PGLOG="${pgLogFile}" export OAUTH2_PGLOG="${pgLogFile}"
export OAUTH2_PGDIR="${pgDir}"
zsh
pg_ctl stop -D "${pgDir}"
rm -rvf "${pgDir}" "${pgSockDir}" "${pgLogFile}"

View File

@ -91,6 +91,7 @@ executable oauth2-mock-server-exe
, transformers , transformers
, uuid , uuid
, warp , warp
, yaml
default-language: Haskell2010 default-language: Haskell2010
test-suite oauth2-mock-server-test test-suite oauth2-mock-server-test

View File

@ -75,6 +75,7 @@ executables:
- monad-logger - monad-logger
- conduit - conduit
- mtl - mtl
- yaml
tests: tests:
oauth2-mock-server-test: oauth2-mock-server-test:

View File

@ -3,27 +3,36 @@
-- --
-- SPDX-License-Identifier: AGPL-3.0-or-later -- SPDX-License-Identifier: AGPL-3.0-or-later
{-# LANGUAGE OverloadedRecordDot, OverloadedStrings, ScopedTypeVariables #-} {-# LANGUAGE OverloadedRecordDot, OverloadedStrings, ScopedTypeVariables, TypeApplications, LambdaCase #-}
module AuthCode module AuthCode
( State(..) ( State(..)
, AuthState , AuthState
, AuthRequest(..) , AuthRequest(..)
, JWT(..) , JWT(..)
, JWTWrapper(..)
, genUnencryptedCode , genUnencryptedCode
, verify , verify
, mkToken
, decodeToken
, renewToken
) where ) where
import User import User
import Data.Aeson import Data.Aeson
import Data.ByteString (ByteString (..), fromStrict, toStrict)
import Data.Either (fromRight)
import Data.Map.Strict (Map) import Data.Map.Strict (Map)
import Data.Maybe (isJust, fromMaybe) import Data.Maybe (isJust, fromMaybe, fromJust)
import Data.Time.Clock import Data.Time.Clock
import Data.Text (pack, replace, Text) import Data.Text (pack, replace, Text, stripPrefix)
import Data.Text.Encoding (decodeUtf8, encodeUtf8)
import Data.Text.Encoding.Base64 import Data.Text.Encoding.Base64
import Data.UUID import Data.UUID
import Data.UUID.V4
import qualified Data.ByteString.Char8 as BS
import qualified Data.Map.Strict as M import qualified Data.Map.Strict as M
import Control.Concurrent (forkIO, threadDelay) import Control.Concurrent (forkIO, threadDelay)
@ -31,7 +40,12 @@ import Control.Concurrent.STM.TVar
import Control.Monad (void, (>=>)) import Control.Monad (void, (>=>))
import Control.Monad.STM import Control.Monad.STM
import Jose.Jwa
import Jose.Jwe
import Jose.Jwk (Jwk(..)) import Jose.Jwk (Jwk(..))
import Jose.Jwt hiding (decode, encode)
import Servant.API (FromHttpApiData(..))
data JWT = JWT data JWT = JWT
@ -47,6 +61,31 @@ instance FromJSON JWT where
parseJSON (Object o) = JWT <$> o .: "iss" <*> o .: "exp" <*> o .: "jti" parseJSON (Object o) = JWT <$> o .: "iss" <*> o .: "exp" <*> o .: "jti"
data JWTWrapper = JWTW
{ acessToken :: String
, expiresIn :: NominalDiffTime
, refreshToken :: Maybe String
} deriving (Show)
instance ToJSON JWTWrapper where
toJSON (JWTW a e r) = object
[ "access_token" .= a
, "token_type" .= ("JWT" :: Text)
, "expires_in" .= fromEnum e
, "refresh_token" .= r ]
instance FromJSON JWTWrapper where
parseJSON (Object o) = JWTW
<$> o .: "access_token"
<*> o .: "expires_in"
<*> o .:? "refresh_token"
instance FromHttpApiData JWTWrapper where
parseHeader bs = case decode (fromStrict bs) of
Just x -> Right x
Nothing -> Left "Invalid JWT wrapper"
data AuthRequest user = AuthRequest data AuthRequest user = AuthRequest
{ client :: String { client :: String
, codeExpiration :: NominalDiffTime , codeExpiration :: NominalDiffTime
@ -81,12 +120,11 @@ genUnencryptedCode req url state = do
then (False, s) then (False, s)
else (True, s{ activeCodes = M.insert simpleCode req s.activeCodes }) else (True, s{ activeCodes = M.insert simpleCode req s.activeCodes })
if success then expire simpleCode req.codeExpiration state >> return (Just simpleCode) else return Nothing if success then expire simpleCode req.codeExpiration state >> return (Just simpleCode) else return Nothing
where
expire :: Text -> NominalDiffTime -> AuthState user -> IO ()
expire :: Text -> NominalDiffTime -> AuthState user -> IO () expire code time state = void . forkIO $ do
expire code time state = void . forkIO $ do threadDelay $ fromEnum time
threadDelay $ fromEnum time atomically . modifyTVar state $ \s -> s{ activeCodes = M.delete code s.activeCodes }
atomically . modifyTVar state $ \s -> s{ activeCodes = M.delete code s.activeCodes }
verify :: Text -> Maybe String -> AuthState user -> IO (Maybe (user, [Scope user])) verify :: Text -> Maybe String -> AuthState user -> IO (Maybe (user, [Scope user]))
@ -99,3 +137,45 @@ verify code mClientID state = do
return $ case mData of return $ case mData of
Just (AuthRequest clientID' _ u s) -> if (fromMaybe clientID' mClientID) == clientID' then Just (u, s) else Nothing Just (AuthRequest clientID' _ u s) -> if (fromMaybe clientID' mClientID) == clientID' then Just (u, s) else Nothing
_ -> Nothing _ -> Nothing
mkToken :: user -> [Scope user] -> AuthState user -> IO JWTWrapper
mkToken u scopes state = do
pubKey <- atomically $ readTVar state >>= return . publicKey
now <- getCurrentTime
uuid <- nextRandom
let
lifetimeAT = 3600 :: NominalDiffTime -- TODO make configurable
lifetimeRT = nominalDay -- TODO make configurable
at = JWT "Oauth2MockServer" (lifetimeAT `addUTCTime` now) uuid
rt = JWT "Oauth2MockServer" (lifetimeRT `addUTCTime` now) uuid
encodedAT <- jwkEncode RSA_OAEP_256 A128GCM pubKey (Nested . Jwt . toStrict $ encode at)
encodedRT <- jwkEncode RSA_OAEP_256 A128GCM pubKey (Nested . Jwt . toStrict $ encode rt)
case encodedAT >> encodedRT of
Right _ -> do
let Jwt aToken = fromRight undefined encodedAT
Jwt rToken = fromRight undefined encodedRT
atomically . modifyTVar state $ \s -> s { activeTokens = M.insert uuid (u, scopes) (activeTokens s) }
return $ JWTW (BS.unpack aToken) lifetimeAT (Just $ BS.unpack rToken)
Left e -> error $ show e
decodeToken :: Text -> AuthState user -> IO (Either JwtError JwtContent)
decodeToken token state = do
prKey <- atomically $ readTVar state >>= return . privateKey
jwkDecode prKey $ encodeUtf8 token
renewToken :: JWTWrapper -> AuthState user -> IO (Maybe JWTWrapper)
renewToken (JWTW _ _ rt) state = case rt >>= stripPrefix "Bearer " . pack of
Just t -> decodeToken t state >>= \case
Right (Jwe (header, body)) -> do
let jwt = fromJust . decode @JWT $ fromStrict body
now <- getCurrentTime
if now <= expiration jwt then return Nothing else do
mUser <- atomically . stateTVar state $ \s ->
let (key, tokens) = M.updateLookupWithKey (\_ _ -> Nothing) (jti jwt) s.activeTokens
in (key, s { activeTokens = tokens })
case mUser of
Just (u, scopes) -> Just <$> mkToken u scopes state
Nothing -> return Nothing
Left _ -> return Nothing
Nothing -> return Nothing

View File

@ -22,28 +22,24 @@ import Control.Concurrent.STM.TVar (newTVarIO, readTVar, modifyTVar)
import Control.Exception (bracket) import Control.Exception (bracket)
import Control.Monad (unless, (>=>)) import Control.Monad (unless, (>=>))
import Control.Monad.IO.Class import Control.Monad.IO.Class
import Control.Monad.Trans.Error (Error(..))
import Control.Monad.Trans.Reader import Control.Monad.Trans.Reader
import Data.Aeson import Data.Aeson
import Data.ByteString (ByteString (..), fromStrict, toStrict) import Data.ByteString (fromStrict)
import Data.List (find, elemIndex) import Data.List (find, elemIndex)
import Data.Maybe (fromMaybe, fromJust, isJust, isNothing) import Data.Maybe (fromMaybe, fromJust, isJust, isNothing)
import Data.String (IsString (..)) import Data.String (IsString (..))
import Data.Text hiding (elem, find, head, length, map, null, splitAt, tail, words) import Data.Text hiding (elem, find, head, length, map, null, splitAt, tail, words)
import qualified Data.Text as T import qualified Data.Text as T
import Data.Text.Encoding (decodeUtf8, encodeUtf8)
import Data.Text.Encoding.Base64 import Data.Text.Encoding.Base64
import Data.Time.Clock (NominalDiffTime (..), nominalDay, UTCTime(..), getCurrentTime, addUTCTime) import Data.Time.Clock (NominalDiffTime (..), nominalDay, UTCTime(..), getCurrentTime, addUTCTime)
import Data.UUID.V4
import qualified Data.ByteString.Char8 as BS
import qualified Data.Map.Strict as Map import qualified Data.Map.Strict as Map
import GHC.Read (readPrec, lexP) import GHC.Read (readPrec, lexP)
import Jose.Jwa import Jose.Jwk (generateRsaKeyPair, KeyUse(Enc), KeyId)
import Jose.Jwe
import Jose.Jwk (generateRsaKeyPair, generateSymmetricKey, KeyUse(Enc), KeyId)
import Jose.Jwt hiding (decode, encode) import Jose.Jwt hiding (decode, encode)
import Network.HTTP.Client (newManager, defaultManagerSettings) import Network.HTTP.Client (newManager, defaultManagerSettings)
@ -188,9 +184,10 @@ codeServer = handleCreds
---- Token Endpoint ---- ---- Token Endpoint ----
---------------------- ----------------------
newtype ACode = ACode String deriving (Show)
data ClientData = ClientData --TODO support other flows data ClientData = ClientData --TODO support other flows
{ authCode :: String { authID :: Either ACode JWTWrapper
, clientID :: Maybe String , clientID :: Maybe String
, clientSecret :: Maybe String , clientSecret :: Maybe String
, redirect :: Maybe String , redirect :: Maybe String
@ -203,29 +200,15 @@ instance FromHttpApiData AuthFlow where
instance FromForm ClientData where instance FromForm ClientData where
fromForm f = ClientData fromForm f = ClientData
<$> ((parseUnique @AuthFlow "grant_type" f) *> parseUnique "code" f) <$> ((parseUnique @AuthFlow "grant_type" f) *> ((Left . ACode <$> parseUnique "code" f)
<|> (Right <$> parseUnique "refresh_token" f)))
<*> parseMaybe "client_id" f <*> parseMaybe "client_id" f
<*> parseMaybe "client_secret" f <*> parseMaybe "client_secret" f
<*> parseMaybe "redirect_uri" f <*> parseMaybe "redirect_uri" f
instance Error Text where
strMsg = pack
data JWTWrapper = JWTW
{ token :: String
, expiresIn :: NominalDiffTime
} deriving (Show)
instance ToJSON JWTWrapper where
toJSON (JWTW t e) = object ["access_token" .= t, "token_type" .= ("JWT" :: Text), "expires_in" .= e]
instance FromJSON JWTWrapper where
parseJSON (Object o) = JWTW
<$> o .: "access_token"
<*> o .: "expires_in"
instance FromHttpApiData JWTWrapper where
parseHeader bs = case decode (fromStrict bs) of
Just x -> Right x
Nothing -> Left "Invalid JWT wrapper"
type Token = "token" type Token = "token"
:> ReqBody '[FormUrlEncoded] ClientData :> ReqBody '[FormUrlEncoded] ClientData
@ -240,30 +223,20 @@ tokenEndpoint = provideToken
unless (isNothing (clientID client >> clientSecret client) unless (isNothing (clientID client >> clientSecret client)
|| Client (pack . fromJust $ clientID client) (pack . fromJust $ clientSecret client) `elem` trustedClients) . || Client (pack . fromJust $ clientID client) (pack . fromJust $ clientSecret client) `elem` trustedClients) .
throwError $ err500 { errBody = "Invalid client" } throwError $ err500 { errBody = "Invalid client" }
mUser <- asks (verify (pack $ authCode client) (clientID client)) >>= liftIO -- TODO verify redirect url here case authID client of
unless (isJust mUser) . throwError $ err500 { errBody = "Invalid authorisation code" } Left (ACode authCode) -> do
-- return JWT {token = "", tokenType = "JWT", expiration = 0.25 * nominalDay} mUser <- asks (verify (pack authCode) (clientID client)) >>= liftIO -- TODO verify redirect url here
let (user, scopes) = fromJust mUser unless (isJust mUser) . throwError $ err500 { errBody = "Invalid authorisation code" }
token <- asks (mkToken @user @userData user scopes) >>= liftIO -- return JWT {token = "", tokenType = "JWT", expiration = 0.25 * nominalDay}
liftIO . putStrLn $ "token: " ++ show token let (user, scopes) = fromJust mUser
return token token <- asks (mkToken @user user scopes) >>= liftIO
liftIO . putStrLn $ "token: " ++ show token
return token
mkToken :: forall user userData . UserData user userData Right jwtw -> do
=> user -> [Scope user] -> AuthState user -> IO JWTWrapper mToken <- asks (renewToken @user jwtw) >>= liftIO
mkToken u scopes state = do case mToken of
pubKey <- atomically $ readTVar state >>= return . publicKey Just token -> liftIO (putStrLn $ "refreshed token: " ++ show token) >> return token
now <- getCurrentTime Nothing -> throwError $ err500 { errBody = "Invalid refresh token" }
uuid <- nextRandom
let
lifetime = nominalDay / 24 -- TODO make configurable
jwt = JWT "Oauth2MockServer" (lifetime `addUTCTime` now) uuid
encoded <- jwkEncode RSA_OAEP_256 A128GCM pubKey (Nested . Jwt . toStrict $ encode jwt)
case encoded of
Right (Jwt token) -> do
atomically . modifyTVar state $ \s -> s { activeTokens = Map.insert uuid (u, scopes) (activeTokens s) }
return $ JWTW (BS.unpack token) lifetime
Left e -> error $ show e
---------------------- ----------------------
@ -291,8 +264,8 @@ userEndpoint = handleUserData
handleUserData :: Text -> AuthHandler user (Maybe userData) handleUserData :: Text -> AuthHandler user (Maybe userData)
handleUserData jwtw = do handleUserData jwtw = do
let mToken = stripPrefix "Bearer " jwtw let mToken = stripPrefix "Bearer " jwtw
unless (isJust mToken) . throwError $ err500 { errBody = "Invalid token format"} unless (isJust mToken) . throwError $ err500 { errBody = "Invalid token format" }
token <- asks (decodeToken @user @userData (fromJust mToken)) >>= liftIO token <- asks (decodeToken @user (fromJust mToken)) >>= liftIO
liftIO $ putStrLn "decoded token:" >> print token liftIO $ putStrLn "decoded token:" >> print token
case token of case token of
Left e -> throwError $ err500 { errBody = fromString $ show e } Left e -> throwError $ err500 { errBody = fromString $ show e }
@ -306,11 +279,6 @@ userEndpoint = handleUserData
Nothing -> throwError $ err500 { errBody = "Unknown token" } Nothing -> throwError $ err500 { errBody = "Unknown token" }
decodeToken :: forall user userData . UserData user userData => Text -> AuthState user -> IO (Either JwtError JwtContent)
decodeToken token state = do
prKey <- atomically $ readTVar state >>= return . privateKey
jwkDecode prKey $ encodeUtf8 token
userListEndpoint :: forall user userData . UserData user userData => AuthServer user (UserList userData) userListEndpoint :: forall user userData . UserData user userData => AuthServer user (UserList userData)
userListEndpoint = handleUserData userListEndpoint = handleUserData
where where

0
users.yaml Normal file
View File