diff --git a/.gitignore b/.gitignore index 3539501..fe68769 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,4 @@ .stack-work/ *~ -database/* \ No newline at end of file +database/* +result diff --git a/app/UniWorX.hs b/app/UniWorX.hs index aeabd28..f373ef5 100644 --- a/app/UniWorX.hs +++ b/app/UniWorX.hs @@ -18,7 +18,7 @@ MultiParamTypeClasses, RecordWildCards #-} -module UniWorX (User(..), initDB, testUsers) where +module UniWorX (User(..), initDB) where import User @@ -33,6 +33,7 @@ import Conduit (ResourceT) import Data.Map (Map(..)) import Data.String (IsString(..)) import Data.Text (Text(..)) +import Data.Yaml (decodeFileThrow, FromJSON(..), Value(..), (.:), (.:?)) import qualified Data.Map as M import qualified Data.Text as T @@ -45,18 +46,49 @@ import System.Environment (lookupEnv) share [mkPersist sqlSettings, mkMigrate "migrateAll"] [persistLowerCase| User - name Text - email Text + firstName Text + surname Text + email Text + matricNumber Text Maybe + title Text Maybe + sex Text Maybe + birthday Text Maybe + telephone Text Maybe + mobile Text Maybe + compPersNumber Text Maybe + compDepartment Text Maybe + postAddress Text Maybe deriving Eq Show |] -testUsers :: [User] -- TODO move to db -testUsers = - [ User "Fallback User" "foo@bar.com" - , User "Tina Tester" "tester@campus.lmu.de" - , User "Max Muster" "m@m.mm" ] +instance FromJSON User where + parseJSON (Object o) = User + <$> o .: "userFirstName" + <*> o .: "userSurname" + <*> o .: "userEmail" + <*> o .:? "userMatrikelnummer" + <*> o .:? "userTitle" + <*> o .:? "userSex" + <*> o .:? "userBirthday" + <*> o .:? "userTelephone" + <*> o .:? "userMobile" + <*> o .:? "userCompanyPersonalNumber" + <*> o .:? "userCompanyDepartment" + <*> o .:? "userPostAddress" + parseJSON _ = error "Oauth2 Mock Server: invalid test user format" -runDB :: ReaderT SqlBackend (NoLoggingT (ResourceT IO)) a -> IO a +data TestUserSpec = TestUsers + { specialUsers :: [Map Text User] + , randomUsers :: Map Text [Maybe Text] + } deriving (Show) + +instance FromJSON TestUserSpec where + parseJSON (Object o) = TestUsers <$> o .: "special-users" <*> o .: "random-users" + parseJSON _ = error "Oauth2 Mock Server: invalid test user format" + +type DB = ReaderT SqlBackend (NoLoggingT (ResourceT IO)) + +runDB :: DB a -> IO a runDB action = do Just port <- lookupEnv "OAUTH2_DB_PORT" -- >>= \p -> return $ p <|> Just "9444" Just host <- lookupEnv "OAUTH2_PGHOST" @@ -64,19 +96,39 @@ runDB action = do runStderrLoggingT $ withPostgresqlPool connStr 10 $ \pool -> liftIO $ flip runSqlPersistMPool pool action initDB :: IO () -initDB = runDB $ do - runMigration migrateAll - forM_ testUsers $ void . insert +initDB = do + Just testUserFile <- lookupEnv "OAUTH2_TEST_USERS" + runDB $ do + runMigration migrateAll + testUsers <- decodeFileThrow @DB @TestUserSpec testUserFile + liftIO . putStrLn $ "the test users:\n" ++ show testUsers + let users = M.elems . mconcat $ specialUsers testUsers + forM_ users $ void . insert + instance UserData (Entity User) (Map Text Text) where data Scope (Entity User) = ID | Profile deriving (Show, Read, Eq) readScope = read showScope = show userScope (Entity _ User{..}) ID = M.singleton "id" userEmail - userScope (Entity _ User{..}) Profile = M.fromList [("name", userName), ("email", userEmail)] + userScope (Entity _ User{..}) Profile = M.fromList [(key, val) | (key, Just val) <- + [ ("firstName", Just userFirstName) + , ("surname", Just userSurname) + , ("email", Just userEmail) + , ("matriculationNumber", userMatricNumber) + , ("title", userTitle) + , ("sex", userSex) + , ("birthday", userBirthday) + , ("telephone", userTelephone) + , ("mobile", userMobile) + , ("companyPersonalNumber", userCompPersNumber) + , ("companyDepartment", userCompDepartment) + , ("postAddress", userPostAddress) + ]] lookupUser email _ = runDB $ do user <- selectList [UserEmail ==. email] [] case user of [entity] -> return $ Just entity [] -> return Nothing - _ -> error "Ambiguous User." + _ -> error "Oauth2 Mock Server: Ambiguous User." + diff --git a/flake.nix b/flake.nix index a36bdc8..4e81d4e 100644 --- a/flake.nix +++ b/flake.nix @@ -29,16 +29,56 @@ with haskell.packages."ghc927"; [ ghc haskell-language-server ] ) ); + libPath = pkgs.lib.makeLibraryPath buildInputs; + oms = pkgs.stdenv.mkDerivation { + inherit buildInputs; + inherit name; + pname = name; + src = ./.; + # dontUnpack = true; + buildPhase = '' + HOME=$out + LD_LIBRARY_PATH=${libPath} + mkdir -p $HOME/.stack + stack build --verbose + ''; + installPhase = '' + mkdir -p $out/bin + mv .stack-work/install/${system}/*/*/bin/${name}-exe $out/bin/${name} + echo "moved" + ''; + }; + mkDB = builtins.readFile ./mkDB.sh; + killDB = builtins.readFile ./killDB.sh; in { - packages.${system}.${name} = nixpkgs.legacyPackages.${system}.${name}; - packages.${system}.default = self.packages.${system}.${name}; + packages.${system} = { + ${name} = oms; # nixpkgs.legacyPackages.${system}.${name}; + mkOauth2DB = pkgs.writeScriptBin "mkOauth2DB" '' + #!${pkgs.zsh}/bin/zsh -e + ${mkDB} + ''; + killOauth2DB = pkgs.writeScriptBin "killOauth2DB" '' + #!${pkgs.zsh}/bin/zsh -e + ${killDB} + ''; + default = self.packages.${system}.${name}; + }; - devShells.${system}.default = pkgs.mkShell { - buildInputs = buildInputs; - LD_LIBRARY_PATH = pkgs.lib.makeLibraryPath buildInputs; - shellHook = builtins.readFile ./mkDB.sh; + devShells.${system}.default = pkgs.mkShell { + buildInputs = buildInputs ++ (with self.packages.${system}; [mkOauth2DB killOauth2DB]); + LD_LIBRARY_PATH = libPath; + OAUTH2_HBA = ./hba_file; + OAUTH2_DB_SCHEMA = ./schema.sql; + OAUTH2_TEST_USERS = ./users.yaml; + OAUTH2_SERVER_PORT = 9443; + OAUTH2_DB_PORT = 9444; + shellHook = '' + ${mkDB} + zsh + ${killDB} + ''; + }; }; - }; -} + } diff --git a/killDB.sh b/killDB.sh new file mode 100755 index 0000000..dd10275 --- /dev/null +++ b/killDB.sh @@ -0,0 +1,8 @@ +# SPDX-FileCopyrightText: 2024 UniWorX Systems +# SPDX-FileContributor: David Mosbach +# +# SPDX-License-Identifier: AGPL-3.0-or-later + +pg_ctl stop -D "${OAUTH2_PGDIR}" +rm -rvf "${OAUTH2_PGDIR}" "${OAUTH2_PGHOST}" "${OAUTH2_PGLOG}" + diff --git a/mkDB.sh b/mkDB.sh index 8893f65..ab4c4f7 100755 --- a/mkDB.sh +++ b/mkDB.sh @@ -3,10 +3,9 @@ # # SPDX-License-Identifier: AGPL-3.0-or-later -export OAUTH2_SERVER_PORT=9443 -export OAUTH2_DB_PORT=9444 +[[ -z "${OAUTH2_HBA}" || -z "${OAUTH2_DB_SCHEMA}" ]] && echo "oauth2: missing env vars for hba and/or schema" && exit 1 -tmpdir=./database +tmpdir=${XDG_RUNTIME_DIR}/.oauth2-db if [ ! -d "${tmpdir}" ]; then mkdir ${tmpdir} @@ -19,15 +18,12 @@ pgSockDir=$(mktemp -d --tmpdir="${absdir}" postgresql.sock.XXXXXX) pgLogFile=$(mktemp --tmpdir="${absdir}" postgresql.XXXXXX.log) initdb --no-locale -D "${pgDir}" -pg_ctl start -D "${pgDir}" -l "${pgLogFile}" -w -o "-k ${pgSockDir} -c listen_addresses='::' -c hba_file='hba_file' -p ${OAUTH2_DB_PORT} -h localhost -c unix_socket_permissions=0700 -c max_connections=10 -c session_preload_libraries=auto_explain -c auto_explain.log_min_duration=100ms" -psql -h "${pgSockDir}" -p ${OAUTH2_DB_PORT} -f ./schema.sql postgres +pg_ctl start -D "${pgDir}" -l "${pgLogFile}" -w -o "-k ${pgSockDir} -c listen_addresses='::' -c hba_file='${OAUTH2_HBA}' -p ${OAUTH2_DB_PORT} -h localhost -c unix_socket_permissions=0700 -c max_connections=10 -c session_preload_libraries=auto_explain -c auto_explain.log_min_duration=100ms" +psql -h "${pgSockDir}" -p ${OAUTH2_DB_PORT} -f "${OAUTH2_DB_SCHEMA}" postgres printf "Postgres logfile is %s\nPostgres socket directory is %s\n" "${pgLogFile}" "${pgSockDir}" export OAUTH2_PGHOST="${pgSockDir}" export OAUTH2_PGLOG="${pgLogFile}" +export OAUTH2_PGDIR="${pgDir}" -zsh - -pg_ctl stop -D "${pgDir}" -rm -rvf "${pgDir}" "${pgSockDir}" "${pgLogFile}" diff --git a/oauth2-mock-server.cabal b/oauth2-mock-server.cabal index edafa50..37004ba 100644 --- a/oauth2-mock-server.cabal +++ b/oauth2-mock-server.cabal @@ -91,6 +91,7 @@ executable oauth2-mock-server-exe , transformers , uuid , warp + , yaml default-language: Haskell2010 test-suite oauth2-mock-server-test diff --git a/package.yaml b/package.yaml index 3056d93..158730d 100644 --- a/package.yaml +++ b/package.yaml @@ -75,6 +75,7 @@ executables: - monad-logger - conduit - mtl + - yaml tests: oauth2-mock-server-test: diff --git a/src/AuthCode.hs b/src/AuthCode.hs index 5bb5579..101b232 100644 --- a/src/AuthCode.hs +++ b/src/AuthCode.hs @@ -3,27 +3,36 @@ -- -- SPDX-License-Identifier: AGPL-3.0-or-later -{-# LANGUAGE OverloadedRecordDot, OverloadedStrings, ScopedTypeVariables #-} +{-# LANGUAGE OverloadedRecordDot, OverloadedStrings, ScopedTypeVariables, TypeApplications, LambdaCase #-} module AuthCode ( State(..) , AuthState , AuthRequest(..) , JWT(..) +, JWTWrapper(..) , genUnencryptedCode , verify +, mkToken +, decodeToken +, renewToken ) where import User import Data.Aeson +import Data.ByteString (ByteString (..), fromStrict, toStrict) +import Data.Either (fromRight) import Data.Map.Strict (Map) -import Data.Maybe (isJust, fromMaybe) +import Data.Maybe (isJust, fromMaybe, fromJust) import Data.Time.Clock -import Data.Text (pack, replace, Text) +import Data.Text (pack, replace, Text, stripPrefix) +import Data.Text.Encoding (decodeUtf8, encodeUtf8) import Data.Text.Encoding.Base64 import Data.UUID +import Data.UUID.V4 +import qualified Data.ByteString.Char8 as BS import qualified Data.Map.Strict as M import Control.Concurrent (forkIO, threadDelay) @@ -31,7 +40,12 @@ import Control.Concurrent.STM.TVar import Control.Monad (void, (>=>)) import Control.Monad.STM +import Jose.Jwa +import Jose.Jwe import Jose.Jwk (Jwk(..)) +import Jose.Jwt hiding (decode, encode) + +import Servant.API (FromHttpApiData(..)) data JWT = JWT @@ -47,6 +61,31 @@ instance FromJSON JWT where parseJSON (Object o) = JWT <$> o .: "iss" <*> o .: "exp" <*> o .: "jti" +data JWTWrapper = JWTW + { acessToken :: String + , expiresIn :: NominalDiffTime + , refreshToken :: Maybe String + } deriving (Show) + +instance ToJSON JWTWrapper where + toJSON (JWTW a e r) = object + [ "access_token" .= a + , "token_type" .= ("JWT" :: Text) + , "expires_in" .= fromEnum e + , "refresh_token" .= r ] + +instance FromJSON JWTWrapper where + parseJSON (Object o) = JWTW + <$> o .: "access_token" + <*> o .: "expires_in" + <*> o .:? "refresh_token" + +instance FromHttpApiData JWTWrapper where + parseHeader bs = case decode (fromStrict bs) of + Just x -> Right x + Nothing -> Left "Invalid JWT wrapper" + + data AuthRequest user = AuthRequest { client :: String , codeExpiration :: NominalDiffTime @@ -81,12 +120,11 @@ genUnencryptedCode req url state = do then (False, s) else (True, s{ activeCodes = M.insert simpleCode req s.activeCodes }) if success then expire simpleCode req.codeExpiration state >> return (Just simpleCode) else return Nothing - - -expire :: Text -> NominalDiffTime -> AuthState user -> IO () -expire code time state = void . forkIO $ do - threadDelay $ fromEnum time - atomically . modifyTVar state $ \s -> s{ activeCodes = M.delete code s.activeCodes } + where + expire :: Text -> NominalDiffTime -> AuthState user -> IO () + expire code time state = void . forkIO $ do + threadDelay $ fromEnum time + atomically . modifyTVar state $ \s -> s{ activeCodes = M.delete code s.activeCodes } verify :: Text -> Maybe String -> AuthState user -> IO (Maybe (user, [Scope user])) @@ -99,3 +137,45 @@ verify code mClientID state = do return $ case mData of Just (AuthRequest clientID' _ u s) -> if (fromMaybe clientID' mClientID) == clientID' then Just (u, s) else Nothing _ -> Nothing + + +mkToken :: user -> [Scope user] -> AuthState user -> IO JWTWrapper +mkToken u scopes state = do + pubKey <- atomically $ readTVar state >>= return . publicKey + now <- getCurrentTime + uuid <- nextRandom + let + lifetimeAT = 3600 :: NominalDiffTime -- TODO make configurable + lifetimeRT = nominalDay -- TODO make configurable + at = JWT "Oauth2MockServer" (lifetimeAT `addUTCTime` now) uuid + rt = JWT "Oauth2MockServer" (lifetimeRT `addUTCTime` now) uuid + encodedAT <- jwkEncode RSA_OAEP_256 A128GCM pubKey (Nested . Jwt . toStrict $ encode at) + encodedRT <- jwkEncode RSA_OAEP_256 A128GCM pubKey (Nested . Jwt . toStrict $ encode rt) + case encodedAT >> encodedRT of + Right _ -> do + let Jwt aToken = fromRight undefined encodedAT + Jwt rToken = fromRight undefined encodedRT + atomically . modifyTVar state $ \s -> s { activeTokens = M.insert uuid (u, scopes) (activeTokens s) } + return $ JWTW (BS.unpack aToken) lifetimeAT (Just $ BS.unpack rToken) + Left e -> error $ show e + +decodeToken :: Text -> AuthState user -> IO (Either JwtError JwtContent) +decodeToken token state = do + prKey <- atomically $ readTVar state >>= return . privateKey + jwkDecode prKey $ encodeUtf8 token + +renewToken :: JWTWrapper -> AuthState user -> IO (Maybe JWTWrapper) +renewToken (JWTW _ _ rt) state = case rt >>= stripPrefix "Bearer " . pack of + Just t -> decodeToken t state >>= \case + Right (Jwe (header, body)) -> do + let jwt = fromJust . decode @JWT $ fromStrict body + now <- getCurrentTime + if now <= expiration jwt then return Nothing else do + mUser <- atomically . stateTVar state $ \s -> + let (key, tokens) = M.updateLookupWithKey (\_ _ -> Nothing) (jti jwt) s.activeTokens + in (key, s { activeTokens = tokens }) + case mUser of + Just (u, scopes) -> Just <$> mkToken u scopes state + Nothing -> return Nothing + Left _ -> return Nothing + Nothing -> return Nothing diff --git a/src/Server.hs b/src/Server.hs index 2e1f016..922d66a 100644 --- a/src/Server.hs +++ b/src/Server.hs @@ -22,28 +22,24 @@ import Control.Concurrent.STM.TVar (newTVarIO, readTVar, modifyTVar) import Control.Exception (bracket) import Control.Monad (unless, (>=>)) import Control.Monad.IO.Class +import Control.Monad.Trans.Error (Error(..)) import Control.Monad.Trans.Reader import Data.Aeson -import Data.ByteString (ByteString (..), fromStrict, toStrict) +import Data.ByteString (fromStrict) import Data.List (find, elemIndex) import Data.Maybe (fromMaybe, fromJust, isJust, isNothing) import Data.String (IsString (..)) import Data.Text hiding (elem, find, head, length, map, null, splitAt, tail, words) import qualified Data.Text as T -import Data.Text.Encoding (decodeUtf8, encodeUtf8) import Data.Text.Encoding.Base64 import Data.Time.Clock (NominalDiffTime (..), nominalDay, UTCTime(..), getCurrentTime, addUTCTime) -import Data.UUID.V4 -import qualified Data.ByteString.Char8 as BS import qualified Data.Map.Strict as Map import GHC.Read (readPrec, lexP) -import Jose.Jwa -import Jose.Jwe -import Jose.Jwk (generateRsaKeyPair, generateSymmetricKey, KeyUse(Enc), KeyId) +import Jose.Jwk (generateRsaKeyPair, KeyUse(Enc), KeyId) import Jose.Jwt hiding (decode, encode) import Network.HTTP.Client (newManager, defaultManagerSettings) @@ -188,9 +184,10 @@ codeServer = handleCreds ---- Token Endpoint ---- ---------------------- +newtype ACode = ACode String deriving (Show) data ClientData = ClientData --TODO support other flows - { authCode :: String + { authID :: Either ACode JWTWrapper , clientID :: Maybe String , clientSecret :: Maybe String , redirect :: Maybe String @@ -203,29 +200,15 @@ instance FromHttpApiData AuthFlow where instance FromForm ClientData where fromForm f = ClientData - <$> ((parseUnique @AuthFlow "grant_type" f) *> parseUnique "code" f) + <$> ((parseUnique @AuthFlow "grant_type" f) *> ((Left . ACode <$> parseUnique "code" f) + <|> (Right <$> parseUnique "refresh_token" f))) <*> parseMaybe "client_id" f <*> parseMaybe "client_secret" f <*> parseMaybe "redirect_uri" f +instance Error Text where + strMsg = pack -data JWTWrapper = JWTW - { token :: String - , expiresIn :: NominalDiffTime - } deriving (Show) - -instance ToJSON JWTWrapper where - toJSON (JWTW t e) = object ["access_token" .= t, "token_type" .= ("JWT" :: Text), "expires_in" .= e] - -instance FromJSON JWTWrapper where - parseJSON (Object o) = JWTW - <$> o .: "access_token" - <*> o .: "expires_in" - -instance FromHttpApiData JWTWrapper where - parseHeader bs = case decode (fromStrict bs) of - Just x -> Right x - Nothing -> Left "Invalid JWT wrapper" type Token = "token" :> ReqBody '[FormUrlEncoded] ClientData @@ -240,30 +223,20 @@ tokenEndpoint = provideToken unless (isNothing (clientID client >> clientSecret client) || Client (pack . fromJust $ clientID client) (pack . fromJust $ clientSecret client) `elem` trustedClients) . throwError $ err500 { errBody = "Invalid client" } - mUser <- asks (verify (pack $ authCode client) (clientID client)) >>= liftIO -- TODO verify redirect url here - unless (isJust mUser) . throwError $ err500 { errBody = "Invalid authorisation code" } - -- return JWT {token = "", tokenType = "JWT", expiration = 0.25 * nominalDay} - let (user, scopes) = fromJust mUser - token <- asks (mkToken @user @userData user scopes) >>= liftIO - liftIO . putStrLn $ "token: " ++ show token - return token - - -mkToken :: forall user userData . UserData user userData - => user -> [Scope user] -> AuthState user -> IO JWTWrapper -mkToken u scopes state = do - pubKey <- atomically $ readTVar state >>= return . publicKey - now <- getCurrentTime - uuid <- nextRandom - let - lifetime = nominalDay / 24 -- TODO make configurable - jwt = JWT "Oauth2MockServer" (lifetime `addUTCTime` now) uuid - encoded <- jwkEncode RSA_OAEP_256 A128GCM pubKey (Nested . Jwt . toStrict $ encode jwt) - case encoded of - Right (Jwt token) -> do - atomically . modifyTVar state $ \s -> s { activeTokens = Map.insert uuid (u, scopes) (activeTokens s) } - return $ JWTW (BS.unpack token) lifetime - Left e -> error $ show e + case authID client of + Left (ACode authCode) -> do + mUser <- asks (verify (pack authCode) (clientID client)) >>= liftIO -- TODO verify redirect url here + unless (isJust mUser) . throwError $ err500 { errBody = "Invalid authorisation code" } + -- return JWT {token = "", tokenType = "JWT", expiration = 0.25 * nominalDay} + let (user, scopes) = fromJust mUser + token <- asks (mkToken @user user scopes) >>= liftIO + liftIO . putStrLn $ "token: " ++ show token + return token + Right jwtw -> do + mToken <- asks (renewToken @user jwtw) >>= liftIO + case mToken of + Just token -> liftIO (putStrLn $ "refreshed token: " ++ show token) >> return token + Nothing -> throwError $ err500 { errBody = "Invalid refresh token" } ---------------------- @@ -291,8 +264,8 @@ userEndpoint = handleUserData handleUserData :: Text -> AuthHandler user (Maybe userData) handleUserData jwtw = do let mToken = stripPrefix "Bearer " jwtw - unless (isJust mToken) . throwError $ err500 { errBody = "Invalid token format"} - token <- asks (decodeToken @user @userData (fromJust mToken)) >>= liftIO + unless (isJust mToken) . throwError $ err500 { errBody = "Invalid token format" } + token <- asks (decodeToken @user (fromJust mToken)) >>= liftIO liftIO $ putStrLn "decoded token:" >> print token case token of Left e -> throwError $ err500 { errBody = fromString $ show e } @@ -306,11 +279,6 @@ userEndpoint = handleUserData Nothing -> throwError $ err500 { errBody = "Unknown token" } -decodeToken :: forall user userData . UserData user userData => Text -> AuthState user -> IO (Either JwtError JwtContent) -decodeToken token state = do - prKey <- atomically $ readTVar state >>= return . privateKey - jwkDecode prKey $ encodeUtf8 token - userListEndpoint :: forall user userData . UserData user userData => AuthServer user (UserList userData) userListEndpoint = handleUserData where diff --git a/users.yaml b/users.yaml new file mode 100644 index 0000000..e69de29