Merge branch 'refresh-tokens' into 'main'
Refresh Tokens See merge request mosbach/oauth2-mock-server!3
This commit is contained in:
commit
31f99eef37
3
.gitignore
vendored
3
.gitignore
vendored
@ -1,3 +1,4 @@
|
|||||||
.stack-work/
|
.stack-work/
|
||||||
*~
|
*~
|
||||||
database/*
|
database/*
|
||||||
|
result
|
||||||
|
|||||||
@ -18,7 +18,7 @@
|
|||||||
MultiParamTypeClasses,
|
MultiParamTypeClasses,
|
||||||
RecordWildCards #-}
|
RecordWildCards #-}
|
||||||
|
|
||||||
module UniWorX (User(..), initDB, testUsers) where
|
module UniWorX (User(..), initDB) where
|
||||||
|
|
||||||
import User
|
import User
|
||||||
|
|
||||||
@ -33,6 +33,7 @@ import Conduit (ResourceT)
|
|||||||
import Data.Map (Map(..))
|
import Data.Map (Map(..))
|
||||||
import Data.String (IsString(..))
|
import Data.String (IsString(..))
|
||||||
import Data.Text (Text(..))
|
import Data.Text (Text(..))
|
||||||
|
import Data.Yaml (decodeFileThrow, FromJSON(..), Value(..), (.:), (.:?))
|
||||||
import qualified Data.Map as M
|
import qualified Data.Map as M
|
||||||
import qualified Data.Text as T
|
import qualified Data.Text as T
|
||||||
|
|
||||||
@ -45,18 +46,49 @@ import System.Environment (lookupEnv)
|
|||||||
|
|
||||||
share [mkPersist sqlSettings, mkMigrate "migrateAll"] [persistLowerCase|
|
share [mkPersist sqlSettings, mkMigrate "migrateAll"] [persistLowerCase|
|
||||||
User
|
User
|
||||||
name Text
|
firstName Text
|
||||||
email Text
|
surname Text
|
||||||
|
email Text
|
||||||
|
matricNumber Text Maybe
|
||||||
|
title Text Maybe
|
||||||
|
sex Text Maybe
|
||||||
|
birthday Text Maybe
|
||||||
|
telephone Text Maybe
|
||||||
|
mobile Text Maybe
|
||||||
|
compPersNumber Text Maybe
|
||||||
|
compDepartment Text Maybe
|
||||||
|
postAddress Text Maybe
|
||||||
deriving Eq Show
|
deriving Eq Show
|
||||||
|]
|
|]
|
||||||
|
|
||||||
testUsers :: [User] -- TODO move to db
|
instance FromJSON User where
|
||||||
testUsers =
|
parseJSON (Object o) = User
|
||||||
[ User "Fallback User" "foo@bar.com"
|
<$> o .: "userFirstName"
|
||||||
, User "Tina Tester" "tester@campus.lmu.de"
|
<*> o .: "userSurname"
|
||||||
, User "Max Muster" "m@m.mm" ]
|
<*> o .: "userEmail"
|
||||||
|
<*> o .:? "userMatrikelnummer"
|
||||||
|
<*> o .:? "userTitle"
|
||||||
|
<*> o .:? "userSex"
|
||||||
|
<*> o .:? "userBirthday"
|
||||||
|
<*> o .:? "userTelephone"
|
||||||
|
<*> o .:? "userMobile"
|
||||||
|
<*> o .:? "userCompanyPersonalNumber"
|
||||||
|
<*> o .:? "userCompanyDepartment"
|
||||||
|
<*> o .:? "userPostAddress"
|
||||||
|
parseJSON _ = error "Oauth2 Mock Server: invalid test user format"
|
||||||
|
|
||||||
runDB :: ReaderT SqlBackend (NoLoggingT (ResourceT IO)) a -> IO a
|
data TestUserSpec = TestUsers
|
||||||
|
{ specialUsers :: [Map Text User]
|
||||||
|
, randomUsers :: Map Text [Maybe Text]
|
||||||
|
} deriving (Show)
|
||||||
|
|
||||||
|
instance FromJSON TestUserSpec where
|
||||||
|
parseJSON (Object o) = TestUsers <$> o .: "special-users" <*> o .: "random-users"
|
||||||
|
parseJSON _ = error "Oauth2 Mock Server: invalid test user format"
|
||||||
|
|
||||||
|
type DB = ReaderT SqlBackend (NoLoggingT (ResourceT IO))
|
||||||
|
|
||||||
|
runDB :: DB a -> IO a
|
||||||
runDB action = do
|
runDB action = do
|
||||||
Just port <- lookupEnv "OAUTH2_DB_PORT" -- >>= \p -> return $ p <|> Just "9444"
|
Just port <- lookupEnv "OAUTH2_DB_PORT" -- >>= \p -> return $ p <|> Just "9444"
|
||||||
Just host <- lookupEnv "OAUTH2_PGHOST"
|
Just host <- lookupEnv "OAUTH2_PGHOST"
|
||||||
@ -64,19 +96,39 @@ runDB action = do
|
|||||||
runStderrLoggingT $ withPostgresqlPool connStr 10 $ \pool -> liftIO $ flip runSqlPersistMPool pool action
|
runStderrLoggingT $ withPostgresqlPool connStr 10 $ \pool -> liftIO $ flip runSqlPersistMPool pool action
|
||||||
|
|
||||||
initDB :: IO ()
|
initDB :: IO ()
|
||||||
initDB = runDB $ do
|
initDB = do
|
||||||
runMigration migrateAll
|
Just testUserFile <- lookupEnv "OAUTH2_TEST_USERS"
|
||||||
forM_ testUsers $ void . insert
|
runDB $ do
|
||||||
|
runMigration migrateAll
|
||||||
|
testUsers <- decodeFileThrow @DB @TestUserSpec testUserFile
|
||||||
|
liftIO . putStrLn $ "the test users:\n" ++ show testUsers
|
||||||
|
let users = M.elems . mconcat $ specialUsers testUsers
|
||||||
|
forM_ users $ void . insert
|
||||||
|
|
||||||
|
|
||||||
instance UserData (Entity User) (Map Text Text) where
|
instance UserData (Entity User) (Map Text Text) where
|
||||||
data Scope (Entity User) = ID | Profile deriving (Show, Read, Eq)
|
data Scope (Entity User) = ID | Profile deriving (Show, Read, Eq)
|
||||||
readScope = read
|
readScope = read
|
||||||
showScope = show
|
showScope = show
|
||||||
userScope (Entity _ User{..}) ID = M.singleton "id" userEmail
|
userScope (Entity _ User{..}) ID = M.singleton "id" userEmail
|
||||||
userScope (Entity _ User{..}) Profile = M.fromList [("name", userName), ("email", userEmail)]
|
userScope (Entity _ User{..}) Profile = M.fromList [(key, val) | (key, Just val) <-
|
||||||
|
[ ("firstName", Just userFirstName)
|
||||||
|
, ("surname", Just userSurname)
|
||||||
|
, ("email", Just userEmail)
|
||||||
|
, ("matriculationNumber", userMatricNumber)
|
||||||
|
, ("title", userTitle)
|
||||||
|
, ("sex", userSex)
|
||||||
|
, ("birthday", userBirthday)
|
||||||
|
, ("telephone", userTelephone)
|
||||||
|
, ("mobile", userMobile)
|
||||||
|
, ("companyPersonalNumber", userCompPersNumber)
|
||||||
|
, ("companyDepartment", userCompDepartment)
|
||||||
|
, ("postAddress", userPostAddress)
|
||||||
|
]]
|
||||||
lookupUser email _ = runDB $ do
|
lookupUser email _ = runDB $ do
|
||||||
user <- selectList [UserEmail ==. email] []
|
user <- selectList [UserEmail ==. email] []
|
||||||
case user of
|
case user of
|
||||||
[entity] -> return $ Just entity
|
[entity] -> return $ Just entity
|
||||||
[] -> return Nothing
|
[] -> return Nothing
|
||||||
_ -> error "Ambiguous User."
|
_ -> error "Oauth2 Mock Server: Ambiguous User."
|
||||||
|
|
||||||
|
|||||||
56
flake.nix
56
flake.nix
@ -29,16 +29,56 @@
|
|||||||
with haskell.packages."ghc927"; [ ghc haskell-language-server ]
|
with haskell.packages."ghc927"; [ ghc haskell-language-server ]
|
||||||
)
|
)
|
||||||
);
|
);
|
||||||
|
libPath = pkgs.lib.makeLibraryPath buildInputs;
|
||||||
|
oms = pkgs.stdenv.mkDerivation {
|
||||||
|
inherit buildInputs;
|
||||||
|
inherit name;
|
||||||
|
pname = name;
|
||||||
|
src = ./.;
|
||||||
|
# dontUnpack = true;
|
||||||
|
buildPhase = ''
|
||||||
|
HOME=$out
|
||||||
|
LD_LIBRARY_PATH=${libPath}
|
||||||
|
mkdir -p $HOME/.stack
|
||||||
|
stack build --verbose
|
||||||
|
'';
|
||||||
|
installPhase = ''
|
||||||
|
mkdir -p $out/bin
|
||||||
|
mv .stack-work/install/${system}/*/*/bin/${name}-exe $out/bin/${name}
|
||||||
|
echo "moved"
|
||||||
|
'';
|
||||||
|
};
|
||||||
|
mkDB = builtins.readFile ./mkDB.sh;
|
||||||
|
killDB = builtins.readFile ./killDB.sh;
|
||||||
in {
|
in {
|
||||||
|
|
||||||
packages.${system}.${name} = nixpkgs.legacyPackages.${system}.${name};
|
packages.${system} = {
|
||||||
packages.${system}.default = self.packages.${system}.${name};
|
${name} = oms; # nixpkgs.legacyPackages.${system}.${name};
|
||||||
|
mkOauth2DB = pkgs.writeScriptBin "mkOauth2DB" ''
|
||||||
|
#!${pkgs.zsh}/bin/zsh -e
|
||||||
|
${mkDB}
|
||||||
|
'';
|
||||||
|
killOauth2DB = pkgs.writeScriptBin "killOauth2DB" ''
|
||||||
|
#!${pkgs.zsh}/bin/zsh -e
|
||||||
|
${killDB}
|
||||||
|
'';
|
||||||
|
default = self.packages.${system}.${name};
|
||||||
|
};
|
||||||
|
|
||||||
devShells.${system}.default = pkgs.mkShell {
|
devShells.${system}.default = pkgs.mkShell {
|
||||||
buildInputs = buildInputs;
|
buildInputs = buildInputs ++ (with self.packages.${system}; [mkOauth2DB killOauth2DB]);
|
||||||
LD_LIBRARY_PATH = pkgs.lib.makeLibraryPath buildInputs;
|
LD_LIBRARY_PATH = libPath;
|
||||||
shellHook = builtins.readFile ./mkDB.sh;
|
OAUTH2_HBA = ./hba_file;
|
||||||
|
OAUTH2_DB_SCHEMA = ./schema.sql;
|
||||||
|
OAUTH2_TEST_USERS = ./users.yaml;
|
||||||
|
OAUTH2_SERVER_PORT = 9443;
|
||||||
|
OAUTH2_DB_PORT = 9444;
|
||||||
|
shellHook = ''
|
||||||
|
${mkDB}
|
||||||
|
zsh
|
||||||
|
${killDB}
|
||||||
|
'';
|
||||||
|
};
|
||||||
};
|
};
|
||||||
};
|
}
|
||||||
}
|
|
||||||
|
|
||||||
|
|||||||
8
killDB.sh
Executable file
8
killDB.sh
Executable file
@ -0,0 +1,8 @@
|
|||||||
|
# SPDX-FileCopyrightText: 2024 UniWorX Systems
|
||||||
|
# SPDX-FileContributor: David Mosbach <david.mosbach@uniworx.de>
|
||||||
|
#
|
||||||
|
# SPDX-License-Identifier: AGPL-3.0-or-later
|
||||||
|
|
||||||
|
pg_ctl stop -D "${OAUTH2_PGDIR}"
|
||||||
|
rm -rvf "${OAUTH2_PGDIR}" "${OAUTH2_PGHOST}" "${OAUTH2_PGLOG}"
|
||||||
|
|
||||||
14
mkDB.sh
14
mkDB.sh
@ -3,10 +3,9 @@
|
|||||||
#
|
#
|
||||||
# SPDX-License-Identifier: AGPL-3.0-or-later
|
# SPDX-License-Identifier: AGPL-3.0-or-later
|
||||||
|
|
||||||
export OAUTH2_SERVER_PORT=9443
|
[[ -z "${OAUTH2_HBA}" || -z "${OAUTH2_DB_SCHEMA}" ]] && echo "oauth2: missing env vars for hba and/or schema" && exit 1
|
||||||
export OAUTH2_DB_PORT=9444
|
|
||||||
|
|
||||||
tmpdir=./database
|
tmpdir=${XDG_RUNTIME_DIR}/.oauth2-db
|
||||||
|
|
||||||
if [ ! -d "${tmpdir}" ]; then
|
if [ ! -d "${tmpdir}" ]; then
|
||||||
mkdir ${tmpdir}
|
mkdir ${tmpdir}
|
||||||
@ -19,15 +18,12 @@ pgSockDir=$(mktemp -d --tmpdir="${absdir}" postgresql.sock.XXXXXX)
|
|||||||
pgLogFile=$(mktemp --tmpdir="${absdir}" postgresql.XXXXXX.log)
|
pgLogFile=$(mktemp --tmpdir="${absdir}" postgresql.XXXXXX.log)
|
||||||
|
|
||||||
initdb --no-locale -D "${pgDir}"
|
initdb --no-locale -D "${pgDir}"
|
||||||
pg_ctl start -D "${pgDir}" -l "${pgLogFile}" -w -o "-k ${pgSockDir} -c listen_addresses='::' -c hba_file='hba_file' -p ${OAUTH2_DB_PORT} -h localhost -c unix_socket_permissions=0700 -c max_connections=10 -c session_preload_libraries=auto_explain -c auto_explain.log_min_duration=100ms"
|
pg_ctl start -D "${pgDir}" -l "${pgLogFile}" -w -o "-k ${pgSockDir} -c listen_addresses='::' -c hba_file='${OAUTH2_HBA}' -p ${OAUTH2_DB_PORT} -h localhost -c unix_socket_permissions=0700 -c max_connections=10 -c session_preload_libraries=auto_explain -c auto_explain.log_min_duration=100ms"
|
||||||
psql -h "${pgSockDir}" -p ${OAUTH2_DB_PORT} -f ./schema.sql postgres
|
psql -h "${pgSockDir}" -p ${OAUTH2_DB_PORT} -f "${OAUTH2_DB_SCHEMA}" postgres
|
||||||
|
|
||||||
printf "Postgres logfile is %s\nPostgres socket directory is %s\n" "${pgLogFile}" "${pgSockDir}"
|
printf "Postgres logfile is %s\nPostgres socket directory is %s\n" "${pgLogFile}" "${pgSockDir}"
|
||||||
|
|
||||||
export OAUTH2_PGHOST="${pgSockDir}"
|
export OAUTH2_PGHOST="${pgSockDir}"
|
||||||
export OAUTH2_PGLOG="${pgLogFile}"
|
export OAUTH2_PGLOG="${pgLogFile}"
|
||||||
|
export OAUTH2_PGDIR="${pgDir}"
|
||||||
|
|
||||||
zsh
|
|
||||||
|
|
||||||
pg_ctl stop -D "${pgDir}"
|
|
||||||
rm -rvf "${pgDir}" "${pgSockDir}" "${pgLogFile}"
|
|
||||||
|
|||||||
@ -91,6 +91,7 @@ executable oauth2-mock-server-exe
|
|||||||
, transformers
|
, transformers
|
||||||
, uuid
|
, uuid
|
||||||
, warp
|
, warp
|
||||||
|
, yaml
|
||||||
default-language: Haskell2010
|
default-language: Haskell2010
|
||||||
|
|
||||||
test-suite oauth2-mock-server-test
|
test-suite oauth2-mock-server-test
|
||||||
|
|||||||
@ -75,6 +75,7 @@ executables:
|
|||||||
- monad-logger
|
- monad-logger
|
||||||
- conduit
|
- conduit
|
||||||
- mtl
|
- mtl
|
||||||
|
- yaml
|
||||||
|
|
||||||
tests:
|
tests:
|
||||||
oauth2-mock-server-test:
|
oauth2-mock-server-test:
|
||||||
|
|||||||
@ -3,27 +3,36 @@
|
|||||||
--
|
--
|
||||||
-- SPDX-License-Identifier: AGPL-3.0-or-later
|
-- SPDX-License-Identifier: AGPL-3.0-or-later
|
||||||
|
|
||||||
{-# LANGUAGE OverloadedRecordDot, OverloadedStrings, ScopedTypeVariables #-}
|
{-# LANGUAGE OverloadedRecordDot, OverloadedStrings, ScopedTypeVariables, TypeApplications, LambdaCase #-}
|
||||||
|
|
||||||
module AuthCode
|
module AuthCode
|
||||||
( State(..)
|
( State(..)
|
||||||
, AuthState
|
, AuthState
|
||||||
, AuthRequest(..)
|
, AuthRequest(..)
|
||||||
, JWT(..)
|
, JWT(..)
|
||||||
|
, JWTWrapper(..)
|
||||||
, genUnencryptedCode
|
, genUnencryptedCode
|
||||||
, verify
|
, verify
|
||||||
|
, mkToken
|
||||||
|
, decodeToken
|
||||||
|
, renewToken
|
||||||
) where
|
) where
|
||||||
|
|
||||||
import User
|
import User
|
||||||
|
|
||||||
import Data.Aeson
|
import Data.Aeson
|
||||||
|
import Data.ByteString (ByteString (..), fromStrict, toStrict)
|
||||||
|
import Data.Either (fromRight)
|
||||||
import Data.Map.Strict (Map)
|
import Data.Map.Strict (Map)
|
||||||
import Data.Maybe (isJust, fromMaybe)
|
import Data.Maybe (isJust, fromMaybe, fromJust)
|
||||||
import Data.Time.Clock
|
import Data.Time.Clock
|
||||||
import Data.Text (pack, replace, Text)
|
import Data.Text (pack, replace, Text, stripPrefix)
|
||||||
|
import Data.Text.Encoding (decodeUtf8, encodeUtf8)
|
||||||
import Data.Text.Encoding.Base64
|
import Data.Text.Encoding.Base64
|
||||||
import Data.UUID
|
import Data.UUID
|
||||||
|
import Data.UUID.V4
|
||||||
|
|
||||||
|
import qualified Data.ByteString.Char8 as BS
|
||||||
import qualified Data.Map.Strict as M
|
import qualified Data.Map.Strict as M
|
||||||
|
|
||||||
import Control.Concurrent (forkIO, threadDelay)
|
import Control.Concurrent (forkIO, threadDelay)
|
||||||
@ -31,7 +40,12 @@ import Control.Concurrent.STM.TVar
|
|||||||
import Control.Monad (void, (>=>))
|
import Control.Monad (void, (>=>))
|
||||||
import Control.Monad.STM
|
import Control.Monad.STM
|
||||||
|
|
||||||
|
import Jose.Jwa
|
||||||
|
import Jose.Jwe
|
||||||
import Jose.Jwk (Jwk(..))
|
import Jose.Jwk (Jwk(..))
|
||||||
|
import Jose.Jwt hiding (decode, encode)
|
||||||
|
|
||||||
|
import Servant.API (FromHttpApiData(..))
|
||||||
|
|
||||||
|
|
||||||
data JWT = JWT
|
data JWT = JWT
|
||||||
@ -47,6 +61,31 @@ instance FromJSON JWT where
|
|||||||
parseJSON (Object o) = JWT <$> o .: "iss" <*> o .: "exp" <*> o .: "jti"
|
parseJSON (Object o) = JWT <$> o .: "iss" <*> o .: "exp" <*> o .: "jti"
|
||||||
|
|
||||||
|
|
||||||
|
data JWTWrapper = JWTW
|
||||||
|
{ acessToken :: String
|
||||||
|
, expiresIn :: NominalDiffTime
|
||||||
|
, refreshToken :: Maybe String
|
||||||
|
} deriving (Show)
|
||||||
|
|
||||||
|
instance ToJSON JWTWrapper where
|
||||||
|
toJSON (JWTW a e r) = object
|
||||||
|
[ "access_token" .= a
|
||||||
|
, "token_type" .= ("JWT" :: Text)
|
||||||
|
, "expires_in" .= fromEnum e
|
||||||
|
, "refresh_token" .= r ]
|
||||||
|
|
||||||
|
instance FromJSON JWTWrapper where
|
||||||
|
parseJSON (Object o) = JWTW
|
||||||
|
<$> o .: "access_token"
|
||||||
|
<*> o .: "expires_in"
|
||||||
|
<*> o .:? "refresh_token"
|
||||||
|
|
||||||
|
instance FromHttpApiData JWTWrapper where
|
||||||
|
parseHeader bs = case decode (fromStrict bs) of
|
||||||
|
Just x -> Right x
|
||||||
|
Nothing -> Left "Invalid JWT wrapper"
|
||||||
|
|
||||||
|
|
||||||
data AuthRequest user = AuthRequest
|
data AuthRequest user = AuthRequest
|
||||||
{ client :: String
|
{ client :: String
|
||||||
, codeExpiration :: NominalDiffTime
|
, codeExpiration :: NominalDiffTime
|
||||||
@ -81,12 +120,11 @@ genUnencryptedCode req url state = do
|
|||||||
then (False, s)
|
then (False, s)
|
||||||
else (True, s{ activeCodes = M.insert simpleCode req s.activeCodes })
|
else (True, s{ activeCodes = M.insert simpleCode req s.activeCodes })
|
||||||
if success then expire simpleCode req.codeExpiration state >> return (Just simpleCode) else return Nothing
|
if success then expire simpleCode req.codeExpiration state >> return (Just simpleCode) else return Nothing
|
||||||
|
where
|
||||||
|
expire :: Text -> NominalDiffTime -> AuthState user -> IO ()
|
||||||
expire :: Text -> NominalDiffTime -> AuthState user -> IO ()
|
expire code time state = void . forkIO $ do
|
||||||
expire code time state = void . forkIO $ do
|
threadDelay $ fromEnum time
|
||||||
threadDelay $ fromEnum time
|
atomically . modifyTVar state $ \s -> s{ activeCodes = M.delete code s.activeCodes }
|
||||||
atomically . modifyTVar state $ \s -> s{ activeCodes = M.delete code s.activeCodes }
|
|
||||||
|
|
||||||
|
|
||||||
verify :: Text -> Maybe String -> AuthState user -> IO (Maybe (user, [Scope user]))
|
verify :: Text -> Maybe String -> AuthState user -> IO (Maybe (user, [Scope user]))
|
||||||
@ -99,3 +137,45 @@ verify code mClientID state = do
|
|||||||
return $ case mData of
|
return $ case mData of
|
||||||
Just (AuthRequest clientID' _ u s) -> if (fromMaybe clientID' mClientID) == clientID' then Just (u, s) else Nothing
|
Just (AuthRequest clientID' _ u s) -> if (fromMaybe clientID' mClientID) == clientID' then Just (u, s) else Nothing
|
||||||
_ -> Nothing
|
_ -> Nothing
|
||||||
|
|
||||||
|
|
||||||
|
mkToken :: user -> [Scope user] -> AuthState user -> IO JWTWrapper
|
||||||
|
mkToken u scopes state = do
|
||||||
|
pubKey <- atomically $ readTVar state >>= return . publicKey
|
||||||
|
now <- getCurrentTime
|
||||||
|
uuid <- nextRandom
|
||||||
|
let
|
||||||
|
lifetimeAT = 3600 :: NominalDiffTime -- TODO make configurable
|
||||||
|
lifetimeRT = nominalDay -- TODO make configurable
|
||||||
|
at = JWT "Oauth2MockServer" (lifetimeAT `addUTCTime` now) uuid
|
||||||
|
rt = JWT "Oauth2MockServer" (lifetimeRT `addUTCTime` now) uuid
|
||||||
|
encodedAT <- jwkEncode RSA_OAEP_256 A128GCM pubKey (Nested . Jwt . toStrict $ encode at)
|
||||||
|
encodedRT <- jwkEncode RSA_OAEP_256 A128GCM pubKey (Nested . Jwt . toStrict $ encode rt)
|
||||||
|
case encodedAT >> encodedRT of
|
||||||
|
Right _ -> do
|
||||||
|
let Jwt aToken = fromRight undefined encodedAT
|
||||||
|
Jwt rToken = fromRight undefined encodedRT
|
||||||
|
atomically . modifyTVar state $ \s -> s { activeTokens = M.insert uuid (u, scopes) (activeTokens s) }
|
||||||
|
return $ JWTW (BS.unpack aToken) lifetimeAT (Just $ BS.unpack rToken)
|
||||||
|
Left e -> error $ show e
|
||||||
|
|
||||||
|
decodeToken :: Text -> AuthState user -> IO (Either JwtError JwtContent)
|
||||||
|
decodeToken token state = do
|
||||||
|
prKey <- atomically $ readTVar state >>= return . privateKey
|
||||||
|
jwkDecode prKey $ encodeUtf8 token
|
||||||
|
|
||||||
|
renewToken :: JWTWrapper -> AuthState user -> IO (Maybe JWTWrapper)
|
||||||
|
renewToken (JWTW _ _ rt) state = case rt >>= stripPrefix "Bearer " . pack of
|
||||||
|
Just t -> decodeToken t state >>= \case
|
||||||
|
Right (Jwe (header, body)) -> do
|
||||||
|
let jwt = fromJust . decode @JWT $ fromStrict body
|
||||||
|
now <- getCurrentTime
|
||||||
|
if now <= expiration jwt then return Nothing else do
|
||||||
|
mUser <- atomically . stateTVar state $ \s ->
|
||||||
|
let (key, tokens) = M.updateLookupWithKey (\_ _ -> Nothing) (jti jwt) s.activeTokens
|
||||||
|
in (key, s { activeTokens = tokens })
|
||||||
|
case mUser of
|
||||||
|
Just (u, scopes) -> Just <$> mkToken u scopes state
|
||||||
|
Nothing -> return Nothing
|
||||||
|
Left _ -> return Nothing
|
||||||
|
Nothing -> return Nothing
|
||||||
|
|||||||
@ -22,28 +22,24 @@ import Control.Concurrent.STM.TVar (newTVarIO, readTVar, modifyTVar)
|
|||||||
import Control.Exception (bracket)
|
import Control.Exception (bracket)
|
||||||
import Control.Monad (unless, (>=>))
|
import Control.Monad (unless, (>=>))
|
||||||
import Control.Monad.IO.Class
|
import Control.Monad.IO.Class
|
||||||
|
import Control.Monad.Trans.Error (Error(..))
|
||||||
import Control.Monad.Trans.Reader
|
import Control.Monad.Trans.Reader
|
||||||
|
|
||||||
import Data.Aeson
|
import Data.Aeson
|
||||||
import Data.ByteString (ByteString (..), fromStrict, toStrict)
|
import Data.ByteString (fromStrict)
|
||||||
import Data.List (find, elemIndex)
|
import Data.List (find, elemIndex)
|
||||||
import Data.Maybe (fromMaybe, fromJust, isJust, isNothing)
|
import Data.Maybe (fromMaybe, fromJust, isJust, isNothing)
|
||||||
import Data.String (IsString (..))
|
import Data.String (IsString (..))
|
||||||
import Data.Text hiding (elem, find, head, length, map, null, splitAt, tail, words)
|
import Data.Text hiding (elem, find, head, length, map, null, splitAt, tail, words)
|
||||||
import qualified Data.Text as T
|
import qualified Data.Text as T
|
||||||
import Data.Text.Encoding (decodeUtf8, encodeUtf8)
|
|
||||||
import Data.Text.Encoding.Base64
|
import Data.Text.Encoding.Base64
|
||||||
import Data.Time.Clock (NominalDiffTime (..), nominalDay, UTCTime(..), getCurrentTime, addUTCTime)
|
import Data.Time.Clock (NominalDiffTime (..), nominalDay, UTCTime(..), getCurrentTime, addUTCTime)
|
||||||
import Data.UUID.V4
|
|
||||||
|
|
||||||
import qualified Data.ByteString.Char8 as BS
|
|
||||||
import qualified Data.Map.Strict as Map
|
import qualified Data.Map.Strict as Map
|
||||||
|
|
||||||
import GHC.Read (readPrec, lexP)
|
import GHC.Read (readPrec, lexP)
|
||||||
|
|
||||||
import Jose.Jwa
|
import Jose.Jwk (generateRsaKeyPair, KeyUse(Enc), KeyId)
|
||||||
import Jose.Jwe
|
|
||||||
import Jose.Jwk (generateRsaKeyPair, generateSymmetricKey, KeyUse(Enc), KeyId)
|
|
||||||
import Jose.Jwt hiding (decode, encode)
|
import Jose.Jwt hiding (decode, encode)
|
||||||
|
|
||||||
import Network.HTTP.Client (newManager, defaultManagerSettings)
|
import Network.HTTP.Client (newManager, defaultManagerSettings)
|
||||||
@ -188,9 +184,10 @@ codeServer = handleCreds
|
|||||||
---- Token Endpoint ----
|
---- Token Endpoint ----
|
||||||
----------------------
|
----------------------
|
||||||
|
|
||||||
|
newtype ACode = ACode String deriving (Show)
|
||||||
|
|
||||||
data ClientData = ClientData --TODO support other flows
|
data ClientData = ClientData --TODO support other flows
|
||||||
{ authCode :: String
|
{ authID :: Either ACode JWTWrapper
|
||||||
, clientID :: Maybe String
|
, clientID :: Maybe String
|
||||||
, clientSecret :: Maybe String
|
, clientSecret :: Maybe String
|
||||||
, redirect :: Maybe String
|
, redirect :: Maybe String
|
||||||
@ -203,29 +200,15 @@ instance FromHttpApiData AuthFlow where
|
|||||||
|
|
||||||
instance FromForm ClientData where
|
instance FromForm ClientData where
|
||||||
fromForm f = ClientData
|
fromForm f = ClientData
|
||||||
<$> ((parseUnique @AuthFlow "grant_type" f) *> parseUnique "code" f)
|
<$> ((parseUnique @AuthFlow "grant_type" f) *> ((Left . ACode <$> parseUnique "code" f)
|
||||||
|
<|> (Right <$> parseUnique "refresh_token" f)))
|
||||||
<*> parseMaybe "client_id" f
|
<*> parseMaybe "client_id" f
|
||||||
<*> parseMaybe "client_secret" f
|
<*> parseMaybe "client_secret" f
|
||||||
<*> parseMaybe "redirect_uri" f
|
<*> parseMaybe "redirect_uri" f
|
||||||
|
|
||||||
|
instance Error Text where
|
||||||
|
strMsg = pack
|
||||||
|
|
||||||
data JWTWrapper = JWTW
|
|
||||||
{ token :: String
|
|
||||||
, expiresIn :: NominalDiffTime
|
|
||||||
} deriving (Show)
|
|
||||||
|
|
||||||
instance ToJSON JWTWrapper where
|
|
||||||
toJSON (JWTW t e) = object ["access_token" .= t, "token_type" .= ("JWT" :: Text), "expires_in" .= e]
|
|
||||||
|
|
||||||
instance FromJSON JWTWrapper where
|
|
||||||
parseJSON (Object o) = JWTW
|
|
||||||
<$> o .: "access_token"
|
|
||||||
<*> o .: "expires_in"
|
|
||||||
|
|
||||||
instance FromHttpApiData JWTWrapper where
|
|
||||||
parseHeader bs = case decode (fromStrict bs) of
|
|
||||||
Just x -> Right x
|
|
||||||
Nothing -> Left "Invalid JWT wrapper"
|
|
||||||
|
|
||||||
type Token = "token"
|
type Token = "token"
|
||||||
:> ReqBody '[FormUrlEncoded] ClientData
|
:> ReqBody '[FormUrlEncoded] ClientData
|
||||||
@ -240,30 +223,20 @@ tokenEndpoint = provideToken
|
|||||||
unless (isNothing (clientID client >> clientSecret client)
|
unless (isNothing (clientID client >> clientSecret client)
|
||||||
|| Client (pack . fromJust $ clientID client) (pack . fromJust $ clientSecret client) `elem` trustedClients) .
|
|| Client (pack . fromJust $ clientID client) (pack . fromJust $ clientSecret client) `elem` trustedClients) .
|
||||||
throwError $ err500 { errBody = "Invalid client" }
|
throwError $ err500 { errBody = "Invalid client" }
|
||||||
mUser <- asks (verify (pack $ authCode client) (clientID client)) >>= liftIO -- TODO verify redirect url here
|
case authID client of
|
||||||
unless (isJust mUser) . throwError $ err500 { errBody = "Invalid authorisation code" }
|
Left (ACode authCode) -> do
|
||||||
-- return JWT {token = "", tokenType = "JWT", expiration = 0.25 * nominalDay}
|
mUser <- asks (verify (pack authCode) (clientID client)) >>= liftIO -- TODO verify redirect url here
|
||||||
let (user, scopes) = fromJust mUser
|
unless (isJust mUser) . throwError $ err500 { errBody = "Invalid authorisation code" }
|
||||||
token <- asks (mkToken @user @userData user scopes) >>= liftIO
|
-- return JWT {token = "", tokenType = "JWT", expiration = 0.25 * nominalDay}
|
||||||
liftIO . putStrLn $ "token: " ++ show token
|
let (user, scopes) = fromJust mUser
|
||||||
return token
|
token <- asks (mkToken @user user scopes) >>= liftIO
|
||||||
|
liftIO . putStrLn $ "token: " ++ show token
|
||||||
|
return token
|
||||||
mkToken :: forall user userData . UserData user userData
|
Right jwtw -> do
|
||||||
=> user -> [Scope user] -> AuthState user -> IO JWTWrapper
|
mToken <- asks (renewToken @user jwtw) >>= liftIO
|
||||||
mkToken u scopes state = do
|
case mToken of
|
||||||
pubKey <- atomically $ readTVar state >>= return . publicKey
|
Just token -> liftIO (putStrLn $ "refreshed token: " ++ show token) >> return token
|
||||||
now <- getCurrentTime
|
Nothing -> throwError $ err500 { errBody = "Invalid refresh token" }
|
||||||
uuid <- nextRandom
|
|
||||||
let
|
|
||||||
lifetime = nominalDay / 24 -- TODO make configurable
|
|
||||||
jwt = JWT "Oauth2MockServer" (lifetime `addUTCTime` now) uuid
|
|
||||||
encoded <- jwkEncode RSA_OAEP_256 A128GCM pubKey (Nested . Jwt . toStrict $ encode jwt)
|
|
||||||
case encoded of
|
|
||||||
Right (Jwt token) -> do
|
|
||||||
atomically . modifyTVar state $ \s -> s { activeTokens = Map.insert uuid (u, scopes) (activeTokens s) }
|
|
||||||
return $ JWTW (BS.unpack token) lifetime
|
|
||||||
Left e -> error $ show e
|
|
||||||
|
|
||||||
|
|
||||||
----------------------
|
----------------------
|
||||||
@ -291,8 +264,8 @@ userEndpoint = handleUserData
|
|||||||
handleUserData :: Text -> AuthHandler user (Maybe userData)
|
handleUserData :: Text -> AuthHandler user (Maybe userData)
|
||||||
handleUserData jwtw = do
|
handleUserData jwtw = do
|
||||||
let mToken = stripPrefix "Bearer " jwtw
|
let mToken = stripPrefix "Bearer " jwtw
|
||||||
unless (isJust mToken) . throwError $ err500 { errBody = "Invalid token format"}
|
unless (isJust mToken) . throwError $ err500 { errBody = "Invalid token format" }
|
||||||
token <- asks (decodeToken @user @userData (fromJust mToken)) >>= liftIO
|
token <- asks (decodeToken @user (fromJust mToken)) >>= liftIO
|
||||||
liftIO $ putStrLn "decoded token:" >> print token
|
liftIO $ putStrLn "decoded token:" >> print token
|
||||||
case token of
|
case token of
|
||||||
Left e -> throwError $ err500 { errBody = fromString $ show e }
|
Left e -> throwError $ err500 { errBody = fromString $ show e }
|
||||||
@ -306,11 +279,6 @@ userEndpoint = handleUserData
|
|||||||
Nothing -> throwError $ err500 { errBody = "Unknown token" }
|
Nothing -> throwError $ err500 { errBody = "Unknown token" }
|
||||||
|
|
||||||
|
|
||||||
decodeToken :: forall user userData . UserData user userData => Text -> AuthState user -> IO (Either JwtError JwtContent)
|
|
||||||
decodeToken token state = do
|
|
||||||
prKey <- atomically $ readTVar state >>= return . privateKey
|
|
||||||
jwkDecode prKey $ encodeUtf8 token
|
|
||||||
|
|
||||||
userListEndpoint :: forall user userData . UserData user userData => AuthServer user (UserList userData)
|
userListEndpoint :: forall user userData . UserData user userData => AuthServer user (UserList userData)
|
||||||
userListEndpoint = handleUserData
|
userListEndpoint = handleUserData
|
||||||
where
|
where
|
||||||
|
|||||||
0
users.yaml
Normal file
0
users.yaml
Normal file
Loading…
Reference in New Issue
Block a user