-- SPDX-FileCopyrightText: 2024 UniWorX Systems
-- SPDX-FileContributor: David Mosbach
--
-- SPDX-License-Identifier: AGPL-3.0-or-later

{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedRecordDot #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}

-- | Authorization codes and tokens for the OAuth2 mock server: issuing and
-- redeeming authorization codes, minting encrypted access/refresh tokens and
-- a signed ID token, and renewing tokens.
module AuthCode
  ( State(..)
  , AuthState
  , AuthRequest(..)
  , TokenParams(..)
  , JWT(..)
  , JWTWrapper(..)
  , genUnencryptedCode
  , verify
  , mkToken
  , decodeToken
  , renewToken
  ) where

import Prelude hiding (exp)

import User

import Data.Aeson
import Data.Bool (bool)
import Data.ByteString (ByteString (..), fromStrict, toStrict)
import Data.Either (fromRight)
import Data.List ((\\))
import Data.Map.Strict (Map)
import Data.Maybe (isJust, fromMaybe, fromJust, catMaybes)
import Data.Time.Calendar
import Data.Time.Clock
import Data.Text (pack, replace, Text, stripPrefix)
import Data.Text.Encoding (decodeUtf8, encodeUtf8)
import Data.Text.Encoding.Base64
import Data.UUID hiding (null)
import Data.UUID.V4

import qualified Data.ByteString.Char8 as BS
import qualified Data.Map.Strict as M

import Control.Concurrent (forkIO, threadDelay)
import Control.Concurrent.STM.TVar
import Control.Monad (void, (>=>))
import Control.Monad.STM

import GHC.Generics

import Jose.Jwa
import Jose.Jwe
import Jose.Jwk (Jwk(..))
import Jose.Jwt hiding (decode, encode)
import qualified Jose.Jws as Jws

import Servant.API (FromHttpApiData(..))

import System.Environment (getEnv)

----------------
---- Tokens ----
----------------

-- | Claims encrypted into the access and refresh tokens (see 'mkToken').
data JWT = JWT
  { issuer :: Text
  , expiration :: UTCTime
  , jti :: UUID -- ^ key of the token's entry in 'activeTokens'
  } deriving (Show, Eq)

instance ToJSON JWT where
  toJSON (JWT i e j) = object ["iss" .= i, "exp" .= e, "jti" .= j]

instance FromJSON JWT where
  -- 'withObject' turns non-object input into a parse failure instead of a
  -- pattern-match crash.
  parseJSON = withObject "JWT" $ \o -> JWT
    <$> o .: "iss"
    <*> o .: "exp"
    <*> o .: "jti"

-- | OpenID Connect ID token claims. 'mkToken' fills the time claims as
-- offsets from 1970-01-01 00:00 UTC.
data IDToken = IDT
  { iss :: Text
  , sub :: Text
  , aud :: [Text]
  , exp :: NominalDiffTime
  , iat :: NominalDiffTime
  , auth_time :: Maybe NominalDiffTime
  , nonce :: Maybe Text
  } deriving (Generic, Show)

instance ToJSON IDToken
instance FromJSON IDToken

-- | Token endpoint response: serialized tokens plus the access-token lifetime.
data JWTWrapper = JWTW
  { acessToken :: String
  , expiresIn :: NominalDiffTime
  , refreshToken :: Maybe String
  , idToken :: Maybe String
  } deriving (Show)

instance ToJSON JWTWrapper where
  toJSON (JWTW a e r i) = object
    [ "access_token" .= a
    , "token_type" .= ("JWT" :: Text)
      -- expires_in is a number of seconds. 'fromEnum' on NominalDiffTime
      -- would yield picoseconds, so round the seconds value explicitly; this
      -- also round-trips through the FromJSON instance below.
    , "expires_in" .= (round e :: Integer)
    , "refresh_token" .= r
    , "id_token" .= i
    ]

instance FromJSON JWTWrapper where
  parseJSON = withObject "JWTWrapper" $ \o -> JWTW
    <$> o .: "access_token"
    <*> o .: "expires_in"
    <*> o .:? "refresh_token"
    <*> o .:? "id_token"

instance FromHttpApiData JWTWrapper where
  parseHeader = maybe (Left "Invalid JWT wrapper") Right . decode . fromStrict
  -- Without an explicit definition, parseUrlPiece and parseQueryParam
  -- default to each other and never terminate.
  parseUrlPiece = parseHeader . encodeUtf8

---------------
---- State ----
---------------

-- | A pending authorization request, stored under its authorization code.
data AuthRequest user = AuthRequest
  { client :: String
  , codeExpiration :: NominalDiffTime
  , user :: user
  , scopes :: [Scope' user]
  , rNonce :: Maybe Text
  }

-- | What a redeemed code or live token grants: the user, the granted scopes
-- and the nonce of the original request (if any).
type TokenParams user = (user, [Scope' user], Maybe Text)

-- | Server state: outstanding authorization codes, live tokens (keyed by
-- their @jti@) and the key pair used to encrypt/sign tokens.
data State user = State
  { activeCodes :: Map Text (AuthRequest user)
  , activeTokens :: Map UUID (TokenParams user)
  , publicKey :: Jwk
  , privateKey :: Jwk
  }

type AuthState user = TVar (State user)

-------------------
---- Functions ----
-------------------

-- | Issue an authorization code for @req@ and register it in the state.
--
-- The code is the Base64 encoding (with @/@ and @=@ percent-escaped) of the
-- client id, @url@, the current time and the expiry time, with spaces
-- removed. Returns 'Nothing' if that code is already active. A background
-- thread removes the code again once 'codeExpiration' has elapsed.
genUnencryptedCode :: AuthRequest user -> String -> AuthState user -> IO (Maybe Text)
genUnencryptedCode req url state = do
  now <- getCurrentTime
  let expiresAt = req.codeExpiration `addUTCTime` now
      simpleCode =
        replace "/" "%2F" . replace "=" "%3D" . encodeBase64 . pack . filter (/= ' ') $
          req.client <> url <> show now <> show expiresAt
  success <- atomically . stateTVar state $ \s ->
    if M.member simpleCode s.activeCodes
      then (False, s)
      else (True, s { activeCodes = M.insert simpleCode req s.activeCodes })
  if success
    then expire simpleCode req.codeExpiration state >> return (Just simpleCode)
    else return Nothing
  where
    -- Forget @code@ after @time@ has elapsed.
    expire :: Text -> NominalDiffTime -> AuthState user -> IO ()
    expire code time st = void . forkIO $ do
      -- threadDelay takes microseconds; 'fromEnum' on NominalDiffTime gives
      -- picoseconds and would delay a million times too long.
      threadDelay . ceiling $ time * 1000000
      atomically . modifyTVar' st $ \s -> s { activeCodes = M.delete code s.activeCodes }

-- | Redeem an authorization code.
--
-- The code is removed from the state in the same transaction as the lookup,
-- regardless of the outcome, so it can be redeemed at most once. If a client
-- id is given it must equal the one the code was issued to.
verify :: Text -> Maybe String -> AuthState user -> IO (Maybe (TokenParams user))
verify code mClientID state = do
  mData <- atomically . stateTVar state $ \s ->
    (M.lookup code s.activeCodes, s { activeCodes = M.delete code s.activeCodes })
  return $ do
    AuthRequest issuedTo _ u granted n <- mData
    if maybe True (== issuedTo) mClientID
      then Just (u, granted, n)
      else Nothing

-- | Mint an access token, a refresh token and an ID token.
--
-- Access and refresh tokens are JWE-encrypted with the public key and share
-- one fresh @jti@, under which the 'TokenParams' are stored in
-- 'activeTokens'. The ID token is JWS-signed with the private key and only
-- returned when the @openid@ scope was granted. Reads @OAUTH2_SERVER_PORT@
-- (via 'getEnv') for the ID token issuer; calls 'error' if encoding fails.
mkToken :: forall user userData . UserData user userData
        => TokenParams user
        -> Maybe Text -- ^ client_id
        -> AuthState user
        -> IO JWTWrapper
mkToken (u, scopes, nonce) clientID state = do
  keys <- readTVarIO state
  let pubKey = publicKey keys
      privKey = privateKey keys
  now <- getCurrentTime
  uuid <- nextRandom
  port <- pack <$> getEnv "OAUTH2_SERVER_PORT"
  let lifetimeAT = 3600 :: NominalDiffTime -- TODO make configurable
      lifetimeRT = nominalDay -- TODO make configurable
      lifetimeIT = 3600 :: NominalDiffTime -- TODO make configurable
      itRefDate = UTCTime (fromGregorian 1970 1 1) 0
      sinceRef t = t `diffUTCTime` itRefDate
      at = JWT "Oauth2MockServer" (lifetimeAT `addUTCTime` now) uuid
      rt = JWT "Oauth2MockServer" (lifetimeRT `addUTCTime` now) uuid
      it = IDT
        { iss = "http://localhost:" <> port -- TODO maybe make configurable
        , sub = pack . show $ userID @user @userData u
        , aud = catMaybes [clientID]
        , exp = sinceRef (lifetimeIT `addUTCTime` now)
        , iat = sinceRef now
        , auth_time = Just (sinceRef now)
        , nonce = nonce
        }
  encodedAT <- jwkEncode RSA_OAEP_256 A128GCM pubKey (Nested . Jwt . toStrict $ encode at)
  encodedRT <- jwkEncode RSA_OAEP_256 A128GCM pubKey (Nested . Jwt . toStrict $ encode rt)
  encodedIT <- Jws.jwkEncode RS256 privKey (Nested . Jwt . toStrict $ encode it)
  -- Applicative Either: the first failing encoding (AT, then RT, then IT) wins.
  case (,,) <$> encodedAT <*> encodedRT <*> encodedIT of
    Left e -> error $ show e
    Right (Jwt aToken, Jwt rToken, Jwt iToken) -> do
      atomically . modifyTVar' state $ \s ->
        s { activeTokens = M.insert uuid (u, scopes, nonce) (activeTokens s) }
      return JWTW
        { acessToken = BS.unpack aToken
        , expiresIn = lifetimeAT
        , refreshToken = Just $ BS.unpack rToken
          -- An ID token is only meaningful for OpenID Connect requests,
          -- i.e. when the openid scope was granted.
        , idToken = if Left OpenID `elem` scopes then Just (BS.unpack iToken) else Nothing
        }

-- | Decrypt a JWE token with the private key held in the state. This only
-- decodes; it does not check expiry or whether the token is still active.
decodeToken :: Text -> AuthState user -> IO (Either JwtError JwtContent)
decodeToken token state = do
  key <- privateKey <$> readTVarIO state
  jwkDecode key (encodeUtf8 token)

-- | Exchange a (refresh) token for a fresh set of tokens.
--
-- The token must decrypt to a JWE whose body is a valid 'JWT' that has not
-- expired. Its @jti@ entry is then removed from 'activeTokens' (even if the
-- scope check below fails), so a token is usable at most once. New tokens
-- are issued only when the requested scopes are a subset of those originally
-- granted. Every failure yields 'Nothing'.
renewToken :: forall user userData . UserData user userData
           => Text -- ^ token
           -> [Scope' user]
           -> Maybe Text -- ^ client_id
           -> AuthState user
           -> IO (Maybe JWTWrapper) -- TODO more descriptive failures
renewToken t scopes clientID state = decodeToken t state >>= \case
  Right (Jwe (_, body)) -> case decode @JWT (fromStrict body) of
    Nothing -> return Nothing
    Just jwt -> do
      now <- getCurrentTime
      if now >= expiration jwt
        then return Nothing
        else do
          -- NOTE(review): mkToken gives the access token the same jti as the
          -- refresh token, so an access token appears to be accepted here
          -- as well — confirm this is intended.
          granted <- atomically . stateTVar state $ \s ->
            let (old, tokens) = M.updateLookupWithKey (\_ _ -> Nothing) (jti jwt) s.activeTokens
            in (old, s { activeTokens = tokens })
          case granted of
            Just (u, scopes', nonce) | null (scopes \\ scopes') ->
              Just <$> mkToken @user @userData (u, scopes, nonce) clientID state
            _ -> return Nothing
  -- Decryption errors, signed-only or unsecured tokens.
  _ -> return Nothing