-- SPDX-FileCopyrightText: 2024 UniWorX Systems
-- SPDX-FileContributor: David Mosbach
--
-- SPDX-License-Identifier: AGPL-3.0-or-later

{-# LANGUAGE OverloadedRecordDot, OverloadedStrings, ScopedTypeVariables, TypeApplications, LambdaCase #-}

-- | Authorisation codes and JWE-encrypted access/refresh tokens for the
-- OAuth2 mock server. All mutable server state lives in a single 'TVar'
-- ('AuthState').
module AuthCode
( State(..)
, AuthState
, AuthRequest(..)
, JWT(..)
, JWTWrapper(..)
, genUnencryptedCode
, verify
, mkToken
, decodeToken
, renewToken
) where

import User

import Data.Aeson
import Data.ByteString (ByteString (..), fromStrict, toStrict)
import Data.Either (fromRight)
import Data.Map.Strict (Map)
import Data.Maybe (isJust, fromMaybe, fromJust)
import Data.Time.Clock
import Data.Text (pack, replace, Text, stripPrefix)
import Data.Text.Encoding (decodeUtf8, encodeUtf8)
import Data.Text.Encoding.Base64
import Data.UUID
import Data.UUID.V4

import qualified Data.ByteString.Char8 as BS
import qualified Data.Map.Strict as M

import Control.Concurrent (forkIO, threadDelay)
import Control.Concurrent.STM.TVar
import Control.Monad (void, (>=>))
import Control.Monad.STM

import Jose.Jwa
import Jose.Jwe
import Jose.Jwk (Jwk(..))
import Jose.Jwt hiding (decode, encode)

import Servant.API (FromHttpApiData(..))


-- | Claims carried inside both the access token and the refresh token.
--
-- NOTE(review): @exp@ is serialised with aeson's 'UTCTime' instance (an
-- ISO-8601 string), not as an RFC 7519 NumericDate. This is fine as long as
-- only this module reads the tokens.
data JWT = JWT
  { issuer     :: Text     -- ^ @iss@ claim
  , expiration :: UTCTime  -- ^ @exp@ claim
  , jti        :: UUID     -- ^ @jti@ claim; key into 'activeTokens'
  } deriving (Show, Eq)

instance ToJSON JWT where
  toJSON (JWT i e j) = object ["iss" .= i, "exp" .= e, "jti" .= j]

instance FromJSON JWT where
  -- 'withObject' makes non-object input a parse failure (so 'decode'
  -- returns 'Nothing') instead of a pattern-match exception.
  parseJSON = withObject "JWT" $ \o ->
    JWT <$> o .: "iss" <*> o .: "exp" <*> o .: "jti"


-- | Token endpoint response: the encrypted access token, its lifetime and
-- an optional encrypted refresh token.
--
-- NOTE(review): the field name @acessToken@ is misspelt. It is kept because
-- the field is exported via @JWTWrapper(..)@.
data JWTWrapper = JWTW
  { acessToken   :: String
  , expiresIn    :: NominalDiffTime
  , refreshToken :: Maybe String
  } deriving (Show)

instance ToJSON JWTWrapper where
  toJSON (JWTW a e r) = object
    [ "access_token"  .= a
    , "token_type"    .= ("JWT" :: Text)
      -- @expires_in@ is whole seconds. 'fromEnum' on a 'NominalDiffTime'
      -- would yield picoseconds, so convert explicitly.
    , "expires_in"    .= (floor e :: Int)
    , "refresh_token" .= r
    ]

instance FromJSON JWTWrapper where
  parseJSON = withObject "JWTWrapper" $ \o ->
    JWTW <$> o .: "access_token" <*> o .: "expires_in" <*> o .:? "refresh_token"

-- | Parse a 'JWTWrapper' from its JSON encoding.
instance FromHttpApiData JWTWrapper where
  parseHeader bs = case decode (fromStrict bs) of
    Just x  -> Right x
    Nothing -> Left "Invalid JWT wrapper"
  -- The class's minimal definition is @parseUrlPiece | parseQueryParam@,
  -- whose defaults refer to each other. Defining 'parseUrlPiece' keeps
  -- them from looping.
  parseUrlPiece = parseHeader . encodeUtf8


-- | A pending authorisation request, stored in 'activeCodes' under its code
-- until it is redeemed by 'verify' or removed after 'codeExpiration'.
data AuthRequest user = AuthRequest
  { client         :: String            -- ^ client id the code is issued to
  , codeExpiration :: NominalDiffTime   -- ^ lifetime of the issued code
  , user           :: user
  , scopes         :: [Scope user]
  }

-- | Mutable server state.
data State user = State
  { activeCodes  :: Map Text (AuthRequest user)     -- ^ outstanding codes
  , activeTokens :: Map UUID (user, [Scope user])   -- ^ by token @jti@
  , publicKey    :: Jwk                             -- ^ encrypts tokens
  , privateKey   :: Jwk                             -- ^ decrypts tokens
  }

type AuthState user = TVar (State user)


-- | Register an authorisation code for @req@ and return it. Returns
-- 'Nothing' if the generated code is already in use. The code is removed
-- from the state again once @req.codeExpiration@ has elapsed.
--
-- The code is the base64 encoding of the client id, @url@ and the
-- current/expiry timestamps, with @+@, @/@ and @=@ percent-escaped.
--
-- NOTE(review): because it is derived deterministically, the code is
-- guessable. That is acceptable for a mock server only.
genUnencryptedCode :: AuthRequest user -> String -> AuthState user -> IO (Maybe Text)
genUnencryptedCode req url state = do
  now <- getCurrentTime
  let expiresAt  = req.codeExpiration `addUTCTime` now
      rawCode    = encodeBase64 . pack . filter (/= ' ')
                 $ req.client <> url <> show now <> show expiresAt
      simpleCode = replace "+" "%2B" . replace "/" "%2F" . replace "=" "%3D" $ rawCode
  success <- atomically . stateTVar state $ \s ->
    if M.member simpleCode s.activeCodes
      then (False, s)
      else (True, s{ activeCodes = M.insert simpleCode req s.activeCodes })
  if success
    then expire simpleCode req.codeExpiration state >> return (Just simpleCode)
    else return Nothing
  where
    -- Drop @code@ from 'activeCodes' after @time@ has elapsed.
    expire :: Text -> NominalDiffTime -> AuthState user -> IO ()
    expire code time st = void . forkIO $ do
      -- 'threadDelay' takes microseconds. 'fromEnum' on a NominalDiffTime
      -- would give picoseconds, a delay 10^6 times too long.
      threadDelay . floor $ time * 1000000
      atomically . modifyTVar' st $ \s -> s{ activeCodes = M.delete code s.activeCodes }


-- | Redeem an authorisation code and return the user and scopes it was
-- issued for.
--
-- The code is removed from the state whether or not redemption succeeds, so
-- each code can be tried only once. If a client id is supplied, it must
-- equal the one the code was issued to. Otherwise any client is accepted.
verify :: Text -> Maybe String -> AuthState user -> IO (Maybe (user, [Scope user]))
verify code mClientID state = do
  mData <- atomically . stateTVar state $ \s ->
    ( M.lookup code s.activeCodes
    , s{ activeCodes = M.delete code s.activeCodes }
    )
  return $ case mData of
    Just (AuthRequest clientID' _ u s)
      | maybe True (== clientID') mClientID -> Just (u, s)
    _ -> Nothing


-- | Issue a fresh access token (120 s) and refresh token (1 day) for @u@.
--
-- Both tokens carry the same random @jti@, under which @(u, scopes)@ is
-- recorded in 'activeTokens'. The claims are JWE-encrypted (RSA-OAEP-256 /
-- A128GCM) with the state's public key. Throws via 'error' if encryption
-- fails.
mkToken :: user -> [Scope user] -> AuthState user -> IO JWTWrapper
mkToken u scopes state = do
  pubKey <- publicKey <$> readTVarIO state
  now <- getCurrentTime
  uuid <- nextRandom
  let lifetimeAT = 120 :: NominalDiffTime -- TODO make configurable
      lifetimeRT = nominalDay             -- TODO make configurable
      at = JWT "Oauth2MockServer" (lifetimeAT `addUTCTime` now) uuid
      rt = JWT "Oauth2MockServer" (lifetimeRT `addUTCTime` now) uuid
      encrypt claims =
        jwkEncode RSA_OAEP_256 A128GCM pubKey (Nested . Jwt . toStrict $ encode claims)
  encodedAT <- encrypt at
  encodedRT <- encrypt rt
  case (,) <$> encodedAT <*> encodedRT of
    Right (Jwt aToken, Jwt rToken) -> do
      atomically . modifyTVar' state $ \s ->
        s{ activeTokens = M.insert uuid (u, scopes) s.activeTokens }
      return $ JWTW (BS.unpack aToken) lifetimeAT (Just $ BS.unpack rToken)
    Left e -> error $ "mkToken: token encryption failed: " <> show e


-- | Decrypt a token with the state's private key.
decodeToken :: Text -> AuthState user -> IO (Either JwtError JwtContent)
decodeToken token state = do
  prKey <- privateKey <$> readTVarIO state
  jwkDecode prKey $ encodeUtf8 token


-- | Exchange a refresh token for a new token pair.
--
-- The refresh token is expected in the wrapper's 'refreshToken' field as
-- @"Bearer <token>"@. Returns 'Nothing' if:
--
-- * it is missing or lacks the prefix,
-- * it cannot be decrypted or parsed,
-- * its @exp@ has passed, or
-- * its @jti@ is not (or no longer) active.
--
-- A successful renewal removes the old @jti@, so each refresh token works
-- at most once.
renewToken :: JWTWrapper -> AuthState user -> IO (Maybe JWTWrapper)
renewToken (JWTW _ _ rt) state =
  case rt >>= stripPrefix "Bearer " . pack of
    Nothing -> return Nothing
    Just t -> decodeToken t state >>= \case
      Right (Jwe (_, body)) | Just jwt <- decode @JWT (fromStrict body) -> do
        now <- getCurrentTime
        -- A token is only valid strictly before its expiration time.
        if now >= expiration jwt
          then return Nothing
          else do
            mUser <- atomically . stateTVar state $ \s ->
              ( M.lookup (jti jwt) s.activeTokens
              , s{ activeTokens = M.delete (jti jwt) s.activeTokens }
              )
            case mUser of
              Just (u, scopes) -> Just <$> mkToken u scopes state
              Nothing          -> return Nothing
      -- Decryption failure, a non-JWE result, or an unparsable claim set.
      _ -> return Nothing