oauth2-mock-server/src/Server.hs

346 lines
13 KiB
Haskell

{-# LANGUAGE DataKinds, TypeOperators, OverloadedStrings, ScopedTypeVariables, TypeApplications, RecordWildCards, AllowAmbiguousTypes #-}
module Server
{-( insecureOAuthMock'
, runMockServer
-- , runMockServer'
)-} where
import AuthCode
import User
import Control.Applicative ((<|>))
import Control.Concurrent
import Control.Concurrent.STM (atomically)
import Control.Concurrent.STM.TVar (newTVarIO, readTVar, modifyTVar)
import Control.Exception (bracket)
import Control.Monad (unless, (>=>))
import Control.Monad.IO.Class
import Control.Monad.Trans.Reader
import Data.Aeson
import Data.ByteString (ByteString (..), fromStrict, toStrict)
import Data.List (find, elemIndex)
import Data.Maybe (fromMaybe, fromJust, isJust, isNothing)
import Data.String (IsString (..))
import Data.Text hiding (elem, find, head, length, map, null, splitAt, tail, words)
import Data.Text.Encoding (decodeUtf8, encodeUtf8)
import Data.Time.Clock (NominalDiffTime (..), nominalDay, UTCTime(..), getCurrentTime, addUTCTime)
import Data.UUID.V4
import qualified Data.ByteString.Char8 as BS
import qualified Data.Map.Strict as Map
import GHC.Read (readPrec, lexP)
import Jose.Jwa
import Jose.Jwe
import Jose.Jwk (generateRsaKeyPair, generateSymmetricKey, KeyUse(Enc), KeyId)
import Jose.Jwt hiding (decode, encode)
import Network.HTTP.Client (newManager, defaultManagerSettings)
import Network.Wai.Handler.Warp
import Servant
import Servant.Client
import Servant.API
import Text.ParserCombinators.ReadPrec (look, pfail)
import qualified Text.Read.Lex as Lex
import Web.Internal.FormUrlEncoded (FromForm(..), parseUnique, parseMaybe)
-- | Hard-coded user accounts; 'runMockServer' passes this list to
-- 'insecureOAuthMock''.
testUsers :: [User] -- TODO move to db
testUsers =
  [ User {name = "Fallback User", email = "foo@bar.com", password = "0000", uID = "0"}
  , User {name = "Tina Tester", email = "t@t.tt", password = "1111", uID = "1"}
  , User {name = "Max Muster", email = "m@m.mm", password = "2222", uID = "2"}]
-- | An OAuth2 client known to the server, identified by an id/secret pair.
data AuthClient = Client
  { ident :: Text  -- ^ client id, compared against the @client_id@ parameter
  , secret :: Text -- ^ client secret, compared against @client_secret@
  } deriving (Show, Eq)
-- | Clients accepted by the authorisation and token endpoints.
trustedClients :: [AuthClient] -- TODO move to db
trustedClients = [Client "42" "shhh"]
-- | Values accepted for the @response_type@ query parameter.
data ResponseType
  = Code    -- ^ authorisation code grant
  | Token   -- ^ implicit grant via access token
  | IDToken -- ^ implicit grant via access token & ID token
  deriving (Eq, Show)

-- | Reads the wire spelling (@code@, @token@, @id_token@) rather than the
-- constructor name. The whole input must consist of that single identifier;
-- anything else fails to parse.
instance Read ResponseType where
  readPrec = do
    Lex.Ident keyword <- lexP
    -- Require end of input so trailing tokens are rejected.
    Lex.EOF <- lexP
    maybe pfail return $
      lookup keyword [("code", Code), ("token", Token), ("id_token", IDToken)]
------------------------------
---- Authorisation endpoint ----
------------------------------
-- Raw (unparsed) types of the authorisation endpoint's query parameters.
type QScope = String
type QClient = String
type QResType = String
type QRedirect = String
type QState = Text
-- | A query parameter carrying servant's 'Required' and 'Strict' modifiers.
type QParam = QueryParam' [Required, Strict]
-- | Authorisation endpoint: @GET /auth@ behind HTTP basic auth (realm
-- @login@). @scope@, @client_id@, @response_type@ and @redirect_uri@ use
-- 'QParam'; @state@ is a plain optional 'QueryParam'.
type Auth user userData = BasicAuth "login" user
  :> "auth"
  :> QParam "scope" QScope
  :> QParam "client_id" QClient
  :> QParam "response_type" QResType
  :> QParam "redirect_uri" QRedirect
  :> QueryParam "state" QState
  :> Get '[JSON] userData
-- type Insert = "insert" :> Post '[JSON] User
-- | Handler monad of this server: servant's 'Handler' with read access to
-- the 'AuthState'.
type AuthHandler user = ReaderT (AuthState user) Handler
-- | Server implementation of API @a@ running in 'AuthHandler'.
type AuthServer user a = ServerT a (AuthHandler user)
-- | Run an 'AuthHandler' action as a plain servant 'Handler' by supplying
-- the given state as its reader environment.
toHandler :: forall user userData a . UserData user userData => AuthState user -> AuthHandler user a -> Handler a
toHandler = flip runReaderT
-- | Handler for the authorisation endpoint.
--
-- Responds with 404 if @client_id@ is not among 'trustedClients', and with
-- 500 if @response_type@ cannot be parsed or is anything other than @code@.
-- Otherwise it requests an authorisation code for the authenticated user and
-- answers with a 303 redirect to @redirect_uri@, extended with the code and
-- the optional @state@ (see 'addParams').
authServer :: forall user userData . UserData user userData => AuthServer user (Auth user userData)
authServer = handleAuth
  where
    handleAuth :: user
               -> QScope
               -> QClient
               -> QResType
               -> QRedirect
               -> Maybe QState
               -> AuthHandler user userData
    handleAuth u scopes client responseType url mState = do
      unless (isJust $ find (\c -> ident c == pack client) trustedClients) . -- TODO fetch trusted clients from db | TODO also check if the redirect url really belongs to the client
        throwError $ err404 { errBody = "Not a trusted client."}
      -- Parse totally: 'read' would throw an imprecise exception on input
      -- that is not a known response type, and this value is client-supplied.
      responseType' <- case reads @ResponseType responseType of
        [(parsed, rest)] | null rest -> return parsed
        _ -> throwError err500 { errBody = "Unsupported response type" }
      let scopes' = map (readScope @user @userData) $ words scopes
      liftIO $ print responseType'
      unless (responseType' == Code) $ throwError err500 { errBody = "Unsupported response type" }
      -- NOTE(review): 600 is presumably the code lifetime in seconds -- confirm in AuthCode.
      mAuthCode <- asks (genUnencryptedCode (AuthRequest client 600 u scopes') url) >>= liftIO
      liftIO $ print mAuthCode
      liftIO . putStrLn $ "user: " ++ show u ++ " | scopes: " ++ show (map (showScope @user @userData) scopes')
      redirect $ addParams url mAuthCode mState

    -- Servant models redirects as errors: send the client to the given
    -- location, or fail with 500 if no authorisation code could be generated.
    redirect :: Maybe ByteString -> AuthHandler user userData
    redirect (Just url) = liftIO (print url) >> throwError err303 { errHeaders = [("Location", url)]}
    redirect Nothing = throwError err500 { errBody = "Could not generate authorisation code."}
-- | Build the redirect target for a successful authorisation request.
--
-- Returns 'Nothing' when no authorisation code is available. Otherwise the
-- @code@ parameter is inserted as the first query parameter of the redirect
-- URI, any query parameters already present are kept after it, and the
-- optional @state@ is appended last.
--
-- Both values are percent-encoded (UTF-8, RFC 3986 unreserved characters
-- left as is), so characters such as @+@, @&@, @#@ or spaces cannot corrupt
-- or extend the query string.
addParams :: String -> Maybe Text -> Maybe Text -> Maybe ByteString
addParams _ Nothing _ = Nothing
addParams url (Just code) mState =
  let qPos = fromMaybe (length url) $ elemIndex '?' url
      (pre, post) = splitAt qPos url
      rState = maybe "" (\s -> "&state=" ++ pctEncode s) mState
      -- Re-attach an existing query behind the code; a bare trailing '?'
      -- contributes nothing (avoids emitting an empty "&" segment).
      post' = case post of
        (_ : rest) | not (null rest) -> '&' : rest
        _ -> ""
  in Just . fromString $ pre ++ "?code=" ++ pctEncode code ++ post' ++ rState
  where
    -- Percent-encode every byte of the UTF-8 representation that is not an
    -- unreserved character.
    pctEncode :: Text -> String
    pctEncode t = BS.unpack (encodeUtf8 t) >>= pctEncodeByte

    pctEncodeByte :: Char -> String
    pctEncodeByte c
      | c `elem` unreserved = [c]
      | otherwise =
          -- 'BS.unpack' yields chars in 0..255, so both indices are in 0..15.
          let (hi, lo) = fromEnum c `divMod` 16
          in ['%', hexDigits !! hi, hexDigits !! lo]

    unreserved :: String
    unreserved = ['A' .. 'Z'] ++ ['a' .. 'z'] ++ ['0' .. '9'] ++ "-._~"

    hexDigits :: String
    hexDigits = "0123456789ABCDEF"
----------------------
---- Token Endpoint ----
----------------------
-- | Form body of a token request (see the 'FromForm' instance for the
-- form keys each field is read from).
data ClientData = ClientData --TODO support other flows
  { authCode :: String             -- ^ form key @code@
  , clientID :: Maybe String       -- ^ form key @client_id@
  , clientSecret :: Maybe String   -- ^ form key @client_secret@
  , redirect :: Maybe String       -- ^ form key @redirect_uri@
  } deriving Show
-- | Marker for the only supported grant type, @authorization_code@.
data AuthFlow = AuthFlow

-- | Accepts exactly @authorization_code@; any other value is returned
-- unchanged as the error.
instance FromHttpApiData AuthFlow where
  parseQueryParam raw
    | raw == "authorization_code" = Right AuthFlow
    | otherwise = Left raw
-- | Decodes a token request. @grant_type@ must be present exactly once and
-- parse as 'AuthFlow'; its value is then discarded. @code@ is mandatory, the
-- remaining keys are optional.
instance FromForm ClientData where
  fromForm f = do
    _ <- parseUnique @AuthFlow "grant_type" f
    code <- parseUnique "code" f
    cID <- parseMaybe "client_id" f
    cSecret <- parseMaybe "client_secret" f
    redirectUri <- parseMaybe "redirect_uri" f
    return $ ClientData code cID cSecret redirectUri
-- | Response body of the token endpoint.
data JWTWrapper = JWTW
  { token :: String                -- ^ serialised as @access_token@
  , expiresIn :: NominalDiffTime   -- ^ serialised as @expires_in@
  } deriving (Show)
-- | Serialises to an object with @access_token@, @expires_in@ and a
-- constant @token_type@ of @JWT@.
instance ToJSON JWTWrapper where
  toJSON (JWTW accessToken lifetime) =
    object
      [ "access_token" .= accessToken
      , "token_type" .= ("JWT" :: Text)
      , "expires_in" .= lifetime
      ]
-- | Parses the object written by the 'ToJSON' instance, reading
-- @access_token@ and @expires_in@ (@token_type@ is ignored).
--
-- Uses 'withObject' so that non-object JSON produces an ordinary parse
-- failure instead of a pattern-match exception.
instance FromJSON JWTWrapper where
  parseJSON = withObject "JWTWrapper" $ \o -> JWTW
    <$> o .: "access_token"
    <*> o .: "expires_in"
-- | Parses a JSON-encoded 'JWTWrapper' from a header or query-parameter
-- value.
instance FromHttpApiData JWTWrapper where
  parseHeader bs = case decode (fromStrict bs) of
    Just x -> Right x
    Nothing -> Left "Invalid JWT wrapper"
  -- The class's minimal definition is 'parseUrlPiece' or 'parseQueryParam',
  -- and those two default to each other; defining one here keeps them from
  -- looping if they are ever used.
  parseQueryParam = parseHeader . encodeUtf8
-- | Token endpoint: @POST /token@ with a form-urlencoded 'ClientData' body,
-- answering with a JSON 'JWTWrapper'.
type Token = "token"
  :> ReqBody '[FormUrlEncoded] ClientData
  :> Post '[JSON] JWTWrapper
-- | Handler for the token endpoint: exchanges an authorisation code for an
-- access token.
--
-- If both client id and secret are supplied they must match an entry of
-- 'trustedClients'; if either is missing the client check is skipped. The
-- authorisation code is then verified and, on success, a fresh token is
-- issued for the associated user and scopes. Every failure is reported as a
-- 500.
tokenEndpoint :: forall user userData . UserData user userData => AuthServer user Token
tokenEndpoint = provideToken
  where
    provideToken :: ClientData -> AuthHandler user JWTWrapper
    provideToken client = do
      case (clientID client, clientSecret client) of
        (Just cID, Just cSecret)
          | not (Client (pack cID) (pack cSecret) `elem` trustedClients) ->
              throwError $ err500 { errBody = "Invalid client" }
        _ -> return ()
      mUser <- asks (verify (pack $ authCode client) (clientID client)) >>= liftIO -- TODO verify redirect url here
      case mUser of
        Nothing -> throwError $ err500 { errBody = "Invalid authorisation code" }
        Just (user, scopes) -> do
          issued <- asks (mkToken @user @userData user scopes) >>= liftIO
          liftIO . putStrLn $ "token: " ++ show issued
          return issued
-- | Issue an access token for the given user and scopes.
--
-- Builds a 'JWT' with issuer @Oauth2MockServer@, an expiry one hour
-- (@nominalDay / 24@) from now and a random UUID, encrypts it with the
-- server's public key, and records the UUID together with the user and
-- scopes in 'activeTokens'. Calls 'error' if encryption fails.
mkToken :: forall user userData . UserData user userData
        => user -> [Scope user] -> AuthState user -> IO JWTWrapper
mkToken u scopes state = do
  pubKey <- atomically $ publicKey <$> readTVar state
  now <- getCurrentTime
  uuid <- nextRandom
  let lifetime = nominalDay / 24 -- TODO make configurable
      expiry = addUTCTime lifetime now
      payload = Nested . Jwt . toStrict . encode $ JWT "Oauth2MockServer" expiry uuid
  result <- jwkEncode RSA_OAEP_256 A128GCM pubKey payload
  case result of
    Left e -> error $ show e
    Right (Jwt raw) -> do
      -- Register the token only once it has been produced successfully.
      atomically . modifyTVar state $ \s ->
        s { activeTokens = Map.insert uuid (u, scopes) (activeTokens s) }
      return $ JWTW (BS.unpack raw) lifetime
----------------------
---- Users Endpoint ----
----------------------
-- | Common path prefix of the user endpoints.
type Users = "users"
-- | A request header carrying servant's 'Strict' and 'Required' modifiers.
type HeaderR = Header' [Strict, Required]
-- | @GET /users/me@ with a mandatory @Authorization@ header, answering with
-- JSON user data.
type Me userData = Users
  :> "me"
  :> HeaderR "Authorization" Text
  :> Get '[JSON] userData
-- | @GET /users/query@ with a mandatory @Authorization@ header, answering
-- with a JSON list of user data.
type UserList userData = Users
  :> "query"
  :> HeaderR "Authorization" Text
  :> Get '[JSON] [userData] -- TODO support query params
-- | Handler for @/users/me@: returns the data of the user a bearer token
-- was issued for.
--
-- The @Authorization@ header must have the form @Bearer \<token\>@. The token
-- is decrypted with the server's private key, its payload is decoded as a
-- 'JWT', and the payload's @jti@ is looked up in 'activeTokens'. The result
-- is the combination ('mconcat') of the user's data for every granted scope.
-- Every failure is reported as a 500, including tokens that are not JWE and
-- payloads that are not valid 'JWT' JSON.
--
-- NOTE(review): this handler does not compare the token's expiry with the
-- current time -- confirm whether that is intended.
userEndpoint :: forall user userData . UserData user userData => AuthServer user (Me userData)
userEndpoint = handleUserData
  where
    handleUserData :: Text -> AuthHandler user userData
    handleUserData jwtw = do
      rawToken <- case stripPrefix "Bearer " jwtw of
        Just t -> return t
        Nothing -> throwError $ err500 { errBody = "Invalid token format"}
      token <- asks (decodeToken @user @userData rawToken) >>= liftIO
      liftIO $ putStrLn "decoded token:" >> print token
      case token of
        Left e -> throwError $ err500 { errBody = fromString $ show e }
        Right (Jwe (_, body)) -> do
          -- The payload is client-controlled; reject it instead of crashing
          -- when it is not valid JWT JSON.
          jwt <- case decode @JWT $ fromStrict body of
            Just parsed -> return parsed
            Nothing -> throwError $ err500 { errBody = "Invalid token payload" }
          -- TODO check if token grants access, then read logged in user from cookie
          liftIO $ print jwt
          mUser <- ask >>= liftIO . (atomically . readTVar >=> return . Map.lookup (jti jwt) . activeTokens)
          case mUser of
            Just (u, scopes) -> return . mconcat $ map (userScope @user @userData u) scopes
            Nothing -> throwError $ err500 { errBody = "Unknown token" }
        -- Only encrypted tokens are issued by this server; signed or
        -- unsecured content is rejected rather than hitting a missing case.
        Right _ -> throwError $ err500 { errBody = "Unsupported token type" }
-- | Decode a serialised token with the private key currently held in the
-- server state.
decodeToken :: forall user userData . UserData user userData => Text -> AuthState user -> IO (Either JwtError JwtContent)
decodeToken token state = do
  prKey <- atomically $ privateKey <$> readTVar state
  jwkDecode prKey (encodeUtf8 token)
-- | Handler for @/users/query@.
--
-- Not implemented yet: always answers with 501 instead of crashing the
-- request on 'undefined'.
userListEndpoint :: forall user userData . UserData user userData => AuthServer user (UserList userData)
userListEndpoint = handleUserData
  where
    handleUserData :: Text -> AuthHandler user [userData]
    handleUserData _ =
      throwError $ err501 { errBody = "User list endpoint is not implemented" }
-------------------
---- Server Main ----
-------------------
-- | The complete API: authorisation, token and both user endpoints.
type Routing user userData = Auth user userData
  :<|> Token
  :<|> Me userData
  :<|> UserList userData
-- | Server for 'Routing'; the handlers are listed in the same order as the
-- endpoints in the API type.
routing :: forall user userData . UserData user userData => AuthServer user (Routing user userData)
routing = authServer @user @userData
  :<|> tokenEndpoint @user @userData
  :<|> userEndpoint @user @userData
  :<|> userListEndpoint @user @userData
-- | 'routing' specialised to 'User' with @Map Text Text@ as user data.
exampleAuthServer :: AuthServer User (Routing User (Map.Map Text Text))
exampleAuthServer = routing
-- | Proxy for the API type served by 'insecureOAuthMock''.
authAPI :: Proxy (Routing User (Map.Map Text Text))
authAPI = Proxy
-- insecureOAuthMock :: Application
-- insecureOAuthMock = authAPI `serve` exampleAuthServer
-- | WAI application serving 'authAPI'. Basic auth is checked against the
-- given users, and all handlers run against the given state.
insecureOAuthMock' :: [User] -> AuthState User -> Application
insecureOAuthMock' users s = serveWithContext authAPI ctx handlers
  where
    ctx = authenticate users :. EmptyContext
    ctxProxy = Proxy :: Proxy '[BasicAuthCheck User]
    -- Translate the ReaderT-based handlers into plain servant handlers.
    handlers =
      hoistServerWithContext authAPI ctxProxy (toHandler @User @(Map.Map Text Text) s) exampleAuthServer
-- | Basic-auth check against a list of users: the user name is matched
-- against 'email', the password against 'password'.
--
-- Yields 'NoSuchUser' if no user has the given e-mail address, 'BadPassword'
-- if the password differs, and 'Authorized' otherwise.
--
-- The comparison is done on UTF-8 bytes. Decoding the client-supplied
-- credentials with 'decodeUtf8' would throw on invalid UTF-8, whereas
-- encoding the stored values cannot fail; invalid input simply matches no
-- user.
authenticate :: [User] -> BasicAuthCheck User
authenticate users = BasicAuthCheck $ \authData ->
  return $ case find (\u -> encodeUtf8 (email u) == basicAuthUsername authData) users of
    Nothing -> NoSuchUser
    Just u
      | encodeUtf8 (password u) == basicAuthPassword authData -> Authorized u
      | otherwise -> BadPassword
-- frontend :: BasicAuthData -> ClientM (Map.Map Text Text)
-- frontend ba = client authAPI ba "[ID]" "42" "code" ""
-- | Create a fresh server state and serve the mock with 'testUsers' on the
-- given port.
runMockServer :: Int -> IO ()
runMockServer port =
  mkState @User @(Map.Map Text Text) >>= run port . insecureOAuthMock' testUsers
-- runMockServer' :: Int -> IO ()
-- runMockServer' port = do
-- mgr <- newManager defaultManagerSettings
-- state <- mkState
-- bracket (forkIO . run port $ insecureOAuthMock' testUsers state) killThread $ \_ ->
-- runClientM (frontend $ BasicAuthData "foo@bar.com" "0000") (mkClientEnv mgr (BaseUrl Http "localhost" port ""))
-- >>= print
-- | Create the initial server state: a freshly generated RSA encryption key
-- pair (key id @Oauth2MockKey@) and empty maps of active codes and tokens.
mkState :: forall user userData . UserData user userData => IO (AuthState user)
mkState = do
  (publicKey, privateKey) <- generateRsaKeyPair 256 (KeyId "Oauth2MockKey") Enc Nothing
  let
    activeCodes = Map.empty
    activeTokens = Map.empty
  -- RecordWildCards: the State fields are filled from the local bindings of
  -- the same names above.
  newTVarIO State{..}