mirror of
https://github.com/freckle/yesod-auth-oauth2.git
synced 2026-03-02 19:34:37 +01:00
Consolidate all errors, use onErrorHtml
Prior to this commit, some errors would be thrown (missing parameter, invalid state, incorrect approot) while others would be handled via the set-message-redirect approach (handshake failure, fetch-token failure, etc). This commit consolidates all of these cases into a single DispatchError type, and then uses MonadError (concretely ExceptT) to capture them all and handle them in one place ourselves. It then updates that handling to: - Use onErrorHtml onErrorHtml will, by default, set-message-redirect. That make this behavior neutral for users running defaults. For users that have customized this, it will be an improvement that all our error cases now respect it. - Provided a JSON representation of errors - Attach a random correlation identifier The last two were just nice-to-haves that were cheap to add once the code was in this state. Note that the use of MonadError requires a potentially "bad" orphan MonadUnliftIO instance for ExceptT, but I'd like to see that instance become a reality and think it needs some real-world experimentation to get there, so here I am.
This commit is contained in:
parent
cfcd8c5210
commit
8b3908ec91
@ -32,11 +32,13 @@ library:
|
|||||||
- http-types >=0.8 && <0.13
|
- http-types >=0.8 && <0.13
|
||||||
- memory
|
- memory
|
||||||
- microlens
|
- microlens
|
||||||
|
- mtl
|
||||||
- safe-exceptions
|
- safe-exceptions
|
||||||
- text >=0.7 && <2.0
|
- text >=0.7 && <2.0
|
||||||
- uri-bytestring
|
- uri-bytestring
|
||||||
- yesod-auth >=1.6.0 && <1.7
|
- yesod-auth >=1.6.0 && <1.7
|
||||||
- yesod-core >=1.6.0 && <1.7
|
- yesod-core >=1.6.0 && <1.7
|
||||||
|
- unliftio
|
||||||
|
|
||||||
executables:
|
executables:
|
||||||
yesod-auth-oauth2-example:
|
yesod-auth-oauth2-example:
|
||||||
|
|||||||
12
src/UnliftIO/Except.hs
Normal file
12
src/UnliftIO/Except.hs
Normal file
@ -0,0 +1,12 @@
|
|||||||
|
{-# OPTIONS_GHC -Wno-orphans #-}
|
||||||
|
|
||||||
|
module UnliftIO.Except
|
||||||
|
() where
|
||||||
|
|
||||||
|
import Control.Monad.Except
|
||||||
|
import UnliftIO
|
||||||
|
|
||||||
|
instance (MonadUnliftIO m, Exception e) => MonadUnliftIO (ExceptT e m) where
|
||||||
|
withRunInIO exceptToIO = ExceptT $ try $ do
|
||||||
|
withRunInIO $ \runInIO ->
|
||||||
|
exceptToIO (runInIO . (either throwIO pure <=< runExceptT))
|
||||||
@ -1,36 +1,30 @@
|
|||||||
{-# LANGUAGE FlexibleContexts #-}
|
{-# LANGUAGE FlexibleContexts #-}
|
||||||
{-# LANGUAGE LambdaCase #-}
|
|
||||||
{-# LANGUAGE OverloadedStrings #-}
|
{-# LANGUAGE OverloadedStrings #-}
|
||||||
{-# LANGUAGE RankNTypes #-}
|
{-# LANGUAGE RankNTypes #-}
|
||||||
{-# LANGUAGE RecordWildCards #-}
|
|
||||||
{-# LANGUAGE ScopedTypeVariables #-}
|
{-# LANGUAGE ScopedTypeVariables #-}
|
||||||
{-# LANGUAGE TemplateHaskell #-}
|
|
||||||
{-# LANGUAGE TypeApplications #-}
|
|
||||||
{-# LANGUAGE TypeFamilies #-}
|
{-# LANGUAGE TypeFamilies #-}
|
||||||
|
|
||||||
module Yesod.Auth.OAuth2.Dispatch
|
module Yesod.Auth.OAuth2.Dispatch
|
||||||
( FetchToken
|
( FetchToken
|
||||||
, fetchAccessToken
|
, fetchAccessToken
|
||||||
, fetchAccessToken2
|
, fetchAccessToken2
|
||||||
, FetchCreds
|
, FetchCreds
|
||||||
, dispatchAuthRequest
|
, dispatchAuthRequest
|
||||||
)
|
) where
|
||||||
where
|
|
||||||
|
|
||||||
import Control.Exception.Safe
|
import Control.Monad.Except
|
||||||
import Control.Monad (unless, (<=<))
|
|
||||||
import Crypto.Random (getRandomBytes)
|
|
||||||
import Data.ByteArray.Encoding (Base(Base64), convertToBase)
|
|
||||||
import Data.ByteString (ByteString)
|
|
||||||
import Data.Text (Text)
|
import Data.Text (Text)
|
||||||
import qualified Data.Text as T
|
import qualified Data.Text as T
|
||||||
import Data.Text.Encoding (decodeUtf8, encodeUtf8)
|
import Data.Text.Encoding (encodeUtf8)
|
||||||
import Network.HTTP.Conduit (Manager)
|
import Network.HTTP.Conduit (Manager)
|
||||||
import Network.OAuth.OAuth2
|
import Network.OAuth.OAuth2
|
||||||
import Network.OAuth.OAuth2.TokenRequest (Errors)
|
import Network.OAuth.OAuth2.TokenRequest (Errors)
|
||||||
|
import UnliftIO.Exception
|
||||||
import URI.ByteString.Extension
|
import URI.ByteString.Extension
|
||||||
import Yesod.Auth hiding (ServerError)
|
import Yesod.Auth hiding (ServerError)
|
||||||
|
import Yesod.Auth.OAuth2.DispatchError
|
||||||
import Yesod.Auth.OAuth2.ErrorResponse
|
import Yesod.Auth.OAuth2.ErrorResponse
|
||||||
import Yesod.Auth.OAuth2.Exception
|
import Yesod.Auth.OAuth2.Random
|
||||||
import Yesod.Core hiding (ErrorResponse)
|
import Yesod.Core hiding (ErrorResponse)
|
||||||
|
|
||||||
-- | How to fetch an @'OAuth2Token'@
|
-- | How to fetch an @'OAuth2Token'@
|
||||||
@ -53,9 +47,9 @@ dispatchAuthRequest
|
|||||||
-> [Text] -- ^ Path pieces
|
-> [Text] -- ^ Path pieces
|
||||||
-> AuthHandler m TypedContent
|
-> AuthHandler m TypedContent
|
||||||
dispatchAuthRequest name oauth2 _ _ "GET" ["forward"] =
|
dispatchAuthRequest name oauth2 _ _ "GET" ["forward"] =
|
||||||
dispatchForward name oauth2
|
handleDispatchError $ dispatchForward name oauth2
|
||||||
dispatchAuthRequest name oauth2 getToken getCreds "GET" ["callback"] =
|
dispatchAuthRequest name oauth2 getToken getCreds "GET" ["callback"] =
|
||||||
dispatchCallback name oauth2 getToken getCreds
|
handleDispatchError $ dispatchCallback name oauth2 getToken getCreds
|
||||||
dispatchAuthRequest _ _ _ _ _ _ = notFound
|
dispatchAuthRequest _ _ _ _ _ _ = notFound
|
||||||
|
|
||||||
-- | Handle @GET \/forward@
|
-- | Handle @GET \/forward@
|
||||||
@ -63,7 +57,11 @@ dispatchAuthRequest _ _ _ _ _ _ = notFound
|
|||||||
-- 1. Set a random CSRF token in our session
|
-- 1. Set a random CSRF token in our session
|
||||||
-- 2. Redirect to the Provider's authorization URL
|
-- 2. Redirect to the Provider's authorization URL
|
||||||
--
|
--
|
||||||
dispatchForward :: Text -> OAuth2 -> AuthHandler m TypedContent
|
dispatchForward
|
||||||
|
:: (MonadError DispatchError m, MonadAuthHandler site m)
|
||||||
|
=> Text
|
||||||
|
-> OAuth2
|
||||||
|
-> m TypedContent
|
||||||
dispatchForward name oauth2 = do
|
dispatchForward name oauth2 = do
|
||||||
csrf <- setSessionCSRF $ tokenSessionKey name
|
csrf <- setSessionCSRF $ tokenSessionKey name
|
||||||
oauth2' <- withCallbackAndState name oauth2 csrf
|
oauth2' <- withCallbackAndState name oauth2 csrf
|
||||||
@ -76,75 +74,47 @@ dispatchForward name oauth2 = do
|
|||||||
-- 3. Use the AccessToken to construct a @'Creds'@ value for the Provider
|
-- 3. Use the AccessToken to construct a @'Creds'@ value for the Provider
|
||||||
--
|
--
|
||||||
dispatchCallback
|
dispatchCallback
|
||||||
:: Text
|
:: (MonadError DispatchError m, MonadAuthHandler site m)
|
||||||
|
=> Text
|
||||||
-> OAuth2
|
-> OAuth2
|
||||||
-> FetchToken
|
-> FetchToken
|
||||||
-> FetchCreds m
|
-> FetchCreds site
|
||||||
-> AuthHandler m TypedContent
|
-> m TypedContent
|
||||||
dispatchCallback name oauth2 getToken getCreds = do
|
dispatchCallback name oauth2 getToken getCreds = do
|
||||||
csrf <- verifySessionCSRF $ tokenSessionKey name
|
csrf <- verifySessionCSRF $ tokenSessionKey name
|
||||||
onErrorResponse $ oauth2HandshakeError name
|
onErrorResponse $ throwError . OAuth2HandshakeError
|
||||||
code <- requireGetParam "code"
|
code <- requireGetParam "code"
|
||||||
manager <- authHttpManager
|
manager <- authHttpManager
|
||||||
oauth2' <- withCallbackAndState name oauth2 csrf
|
oauth2' <- withCallbackAndState name oauth2 csrf
|
||||||
token <- errLeft $ getToken manager oauth2' $ ExchangeToken code
|
token <-
|
||||||
creds <- errLeft $ tryFetchCreds $ getCreds manager token
|
errLeft OAuth2ResultError $ getToken manager oauth2' $ ExchangeToken
|
||||||
|
code
|
||||||
|
creds <- errLeft id $ tryFetchCreds $ getCreds manager token
|
||||||
setCredsRedirect creds
|
setCredsRedirect creds
|
||||||
where
|
where
|
||||||
errLeft :: Show e => IO (Either e a) -> AuthHandler m a
|
errLeft
|
||||||
errLeft = either (unexpectedError name) pure <=< liftIO
|
:: (MonadIO m, MonadError e m) => (e' -> e) -> IO (Either e' a) -> m a
|
||||||
|
errLeft f = either (throwError . f) pure <=< liftIO
|
||||||
|
|
||||||
-- | Handle an OAuth2 @'ErrorResponse'@
|
tryFetchCreds :: IO a -> IO (Either DispatchError a)
|
||||||
--
|
|
||||||
-- These are things coming from the OAuth2 provider such an Invalid Grant or
|
|
||||||
-- Invalid Scope and /may/ be user-actionable. We've coded them to have an
|
|
||||||
-- @'erUserMessage'@ that we are comfortable displaying to the user as part of
|
|
||||||
-- the redirect, just in case.
|
|
||||||
--
|
|
||||||
oauth2HandshakeError :: Text -> ErrorResponse -> AuthHandler m a
|
|
||||||
oauth2HandshakeError name err = do
|
|
||||||
$(logError) $ "Handshake failure in " <> name <> " plugin: " <> tshow err
|
|
||||||
redirectMessage $ "OAuth2 handshake failure: " <> erUserMessage err
|
|
||||||
|
|
||||||
-- | Handle an unexpected error
|
|
||||||
--
|
|
||||||
-- This would be some unexpected exception while processing the callback.
|
|
||||||
-- Therefore, the user should see an opaque message and the details go only to
|
|
||||||
-- the server logs.
|
|
||||||
--
|
|
||||||
unexpectedError :: Show e => Text -> e -> AuthHandler m a
|
|
||||||
unexpectedError name err = do
|
|
||||||
$(logError) $ "Error in " <> name <> " OAuth2 plugin: " <> tshow err
|
|
||||||
redirectMessage "Unexpected error logging in with OAuth2"
|
|
||||||
|
|
||||||
redirectMessage :: Text -> AuthHandler m a
|
|
||||||
redirectMessage msg = do
|
|
||||||
toParent <- getRouteToParent
|
|
||||||
setMessage $ toHtml msg
|
|
||||||
redirect $ toParent LoginR
|
|
||||||
|
|
||||||
tryFetchCreds :: IO a -> IO (Either SomeException a)
|
|
||||||
tryFetchCreds f =
|
tryFetchCreds f =
|
||||||
(Right <$> f)
|
(Right <$> f)
|
||||||
`catch` (\(ex :: IOException) -> pure $ Left $ toException ex)
|
`catch` (pure . Left . FetchCredsIOException)
|
||||||
`catch` (\(ex :: YesodOAuth2Exception) -> pure $ Left $ toException ex)
|
`catch` (pure . Left . FetchCredsYesodOAuth2Exception)
|
||||||
|
|
||||||
withCallbackAndState :: Text -> OAuth2 -> Text -> AuthHandler m OAuth2
|
withCallbackAndState
|
||||||
|
:: (MonadError DispatchError m, MonadAuthHandler site m)
|
||||||
|
=> Text
|
||||||
|
-> OAuth2
|
||||||
|
-> Text
|
||||||
|
-> m OAuth2
|
||||||
withCallbackAndState name oauth2 csrf = do
|
withCallbackAndState name oauth2 csrf = do
|
||||||
let url = PluginR name ["callback"]
|
let url = PluginR name ["callback"]
|
||||||
render <- getParentUrlRender
|
render <- getParentUrlRender
|
||||||
let callbackText = render url
|
let callbackText = render url
|
||||||
|
|
||||||
callback <-
|
callback <- maybe (throwError $ InvalidCallbackUri callbackText) pure
|
||||||
maybe
|
$ fromText callbackText
|
||||||
(liftIO
|
|
||||||
$ throwString
|
|
||||||
$ "Invalid callback URI: "
|
|
||||||
<> T.unpack callbackText
|
|
||||||
<> ". Not using an absolute Approot?"
|
|
||||||
)
|
|
||||||
pure
|
|
||||||
$ fromText callbackText
|
|
||||||
|
|
||||||
pure oauth2
|
pure oauth2
|
||||||
{ oauthCallback = Just callback
|
{ oauthCallback = Just callback
|
||||||
@ -169,40 +139,28 @@ setSessionCSRF :: MonadHandler m => Text -> m Text
|
|||||||
setSessionCSRF sessionKey = do
|
setSessionCSRF sessionKey = do
|
||||||
csrfToken <- liftIO randomToken
|
csrfToken <- liftIO randomToken
|
||||||
csrfToken <$ setSession sessionKey csrfToken
|
csrfToken <$ setSession sessionKey csrfToken
|
||||||
where
|
where randomToken = T.filter (/= '+') <$> randomText 64
|
||||||
randomToken =
|
|
||||||
T.filter (/= '+')
|
|
||||||
. decodeUtf8
|
|
||||||
. convertToBase @ByteString Base64
|
|
||||||
<$> getRandomBytes 64
|
|
||||||
|
|
||||||
-- | Verify the callback provided the same CSRF token as in our session
|
-- | Verify the callback provided the same CSRF token as in our session
|
||||||
verifySessionCSRF :: MonadHandler m => Text -> m Text
|
verifySessionCSRF
|
||||||
|
:: (MonadError DispatchError m, MonadHandler m) => Text -> m Text
|
||||||
verifySessionCSRF sessionKey = do
|
verifySessionCSRF sessionKey = do
|
||||||
token <- requireGetParam "state"
|
token <- requireGetParam "state"
|
||||||
sessionToken <- lookupSession sessionKey
|
sessionToken <- lookupSession sessionKey
|
||||||
deleteSession sessionKey
|
deleteSession sessionKey
|
||||||
|
|
||||||
unless (sessionToken == Just token) $ do
|
unless (sessionToken == Just token) $ throwError $ InvalidStateToken
|
||||||
$(logError)
|
sessionToken
|
||||||
$ "state token does not match. "
|
token
|
||||||
<> "Param: "
|
|
||||||
<> tshow token
|
|
||||||
<> "State: "
|
|
||||||
<> tshow sessionToken
|
|
||||||
permissionDenied "Invalid OAuth2 state token"
|
|
||||||
|
|
||||||
return token
|
pure token
|
||||||
|
|
||||||
requireGetParam :: MonadHandler m => Text -> m Text
|
requireGetParam
|
||||||
|
:: (MonadError DispatchError m, MonadHandler m) => Text -> m Text
|
||||||
requireGetParam key = do
|
requireGetParam key = do
|
||||||
m <- lookupGetParam key
|
m <- lookupGetParam key
|
||||||
maybe errInvalidArgs return m
|
maybe err return m
|
||||||
where
|
where err = throwError $ MissingParameter key
|
||||||
errInvalidArgs = invalidArgs ["The '" <> key <> "' parameter is required"]
|
|
||||||
|
|
||||||
tokenSessionKey :: Text -> Text
|
tokenSessionKey :: Text -> Text
|
||||||
tokenSessionKey name = "_yesod_oauth2_" <> name
|
tokenSessionKey name = "_yesod_oauth2_" <> name
|
||||||
|
|
||||||
tshow :: Show a => a -> Text
|
|
||||||
tshow = T.pack . show
|
|
||||||
|
|||||||
78
src/Yesod/Auth/OAuth2/DispatchError.hs
Normal file
78
src/Yesod/Auth/OAuth2/DispatchError.hs
Normal file
@ -0,0 +1,78 @@
|
|||||||
|
{-# LANGUAGE DeriveAnyClass #-}
|
||||||
|
{-# LANGUAGE DerivingStrategies #-}
|
||||||
|
{-# LANGUAGE FlexibleContexts #-}
|
||||||
|
{-# LANGUAGE LambdaCase #-}
|
||||||
|
{-# LANGUAGE OverloadedStrings #-}
|
||||||
|
{-# LANGUAGE ScopedTypeVariables #-}
|
||||||
|
{-# LANGUAGE TemplateHaskell #-}
|
||||||
|
{-# LANGUAGE TypeApplications #-}
|
||||||
|
{-# LANGUAGE TypeFamilies #-}
|
||||||
|
|
||||||
|
module Yesod.Auth.OAuth2.DispatchError
|
||||||
|
( DispatchError(..)
|
||||||
|
, handleDispatchError
|
||||||
|
) where
|
||||||
|
|
||||||
|
import Control.Monad.Except
|
||||||
|
import Data.Text (Text, pack)
|
||||||
|
import Network.OAuth.OAuth2
|
||||||
|
import Network.OAuth.OAuth2.TokenRequest (Errors)
|
||||||
|
import UnliftIO.Except ()
|
||||||
|
import UnliftIO.Exception
|
||||||
|
import Yesod.Auth hiding (ServerError)
|
||||||
|
import Yesod.Auth.OAuth2.ErrorResponse
|
||||||
|
import Yesod.Auth.OAuth2.Exception
|
||||||
|
import Yesod.Auth.OAuth2.Random
|
||||||
|
import Yesod.Core hiding (ErrorResponse)
|
||||||
|
|
||||||
|
data DispatchError
|
||||||
|
= MissingParameter Text
|
||||||
|
| InvalidStateToken (Maybe Text) Text
|
||||||
|
| InvalidCallbackUri Text
|
||||||
|
| OAuth2HandshakeError ErrorResponse
|
||||||
|
| OAuth2ResultError (OAuth2Error Errors)
|
||||||
|
| FetchCredsIOException IOException
|
||||||
|
| FetchCredsYesodOAuth2Exception YesodOAuth2Exception
|
||||||
|
deriving stock Show
|
||||||
|
deriving anyclass Exception
|
||||||
|
|
||||||
|
-- | User-friendly message for any given 'DispatchError'
|
||||||
|
--
|
||||||
|
-- Most of these are opaque to the user. The exception details are present for
|
||||||
|
-- the server logs.
|
||||||
|
--
|
||||||
|
dispatchErrorMessage :: DispatchError -> Text
|
||||||
|
dispatchErrorMessage = \case
|
||||||
|
MissingParameter name ->
|
||||||
|
"Parameter '" <> name <> "' is required, but not present in the URL"
|
||||||
|
InvalidStateToken{} -> "State token is invalid, please try again"
|
||||||
|
InvalidCallbackUri{}
|
||||||
|
-> "Callback URI was not valid, this server may be misconfigured (no approot)"
|
||||||
|
OAuth2HandshakeError er -> "OAuth2 handshake failure: " <> erUserMessage er
|
||||||
|
OAuth2ResultError{} -> "Login failed, please try again"
|
||||||
|
FetchCredsIOException{} -> "Login failed, please try again"
|
||||||
|
FetchCredsYesodOAuth2Exception{} -> "Login failed, please try again"
|
||||||
|
|
||||||
|
handleDispatchError
|
||||||
|
:: MonadAuthHandler site m
|
||||||
|
=> ExceptT DispatchError m TypedContent
|
||||||
|
-> m TypedContent
|
||||||
|
handleDispatchError f = do
|
||||||
|
result <- runExceptT f
|
||||||
|
either onDispatchError pure result
|
||||||
|
|
||||||
|
onDispatchError :: MonadAuthHandler site m => DispatchError -> m TypedContent
|
||||||
|
onDispatchError err = do
|
||||||
|
errorId <- liftIO $ randomText 16
|
||||||
|
let suffix = " [errorId=" <> errorId <> "]"
|
||||||
|
$(logError) $ pack (displayException err) <> suffix
|
||||||
|
|
||||||
|
let message = dispatchErrorMessage err <> suffix
|
||||||
|
messageValue =
|
||||||
|
object ["error" .= object ["id" .= errorId, "message" .= message]]
|
||||||
|
|
||||||
|
loginR <- ($ LoginR) <$> getRouteToParent
|
||||||
|
|
||||||
|
selectRep $ do
|
||||||
|
provideRep @_ @Html $ onErrorHtml loginR message
|
||||||
|
provideRep @_ @Value $ pure messageValue
|
||||||
19
src/Yesod/Auth/OAuth2/Random.hs
Normal file
19
src/Yesod/Auth/OAuth2/Random.hs
Normal file
@ -0,0 +1,19 @@
|
|||||||
|
{-# LANGUAGE TypeApplications #-}
|
||||||
|
|
||||||
|
module Yesod.Auth.OAuth2.Random
|
||||||
|
( randomText
|
||||||
|
) where
|
||||||
|
|
||||||
|
import Crypto.Random (MonadRandom, getRandomBytes)
|
||||||
|
import Data.ByteArray.Encoding (Base(Base64), convertToBase)
|
||||||
|
import Data.ByteString (ByteString)
|
||||||
|
import Data.Text (Text)
|
||||||
|
import Data.Text.Encoding (decodeUtf8)
|
||||||
|
|
||||||
|
randomText
|
||||||
|
:: MonadRandom m
|
||||||
|
=> Int
|
||||||
|
-- ^ Size in Bytes (note necessarily characters)
|
||||||
|
-> m Text
|
||||||
|
randomText size =
|
||||||
|
decodeUtf8 . convertToBase @ByteString Base64 <$> getRandomBytes size
|
||||||
Loading…
Reference in New Issue
Block a user