mirror of
https://github.com/freckle/yesod-auth-oauth2.git
synced 2026-01-11 19:58:28 +01:00
Consolidate all errors, use onErrorHtml
Prior to this commit, some errors would be thrown (missing parameter, invalid state, incorrect approot) while others would be handled via the set-message-redirect approach (handshake failure, fetch-token failure, etc). This commit consolidates all of these cases into a single DispatchError type, and then uses MonadError (concretely ExceptT) to capture them all and handle them in one place ourselves. It then updates that handling to: - Use onErrorHtml onErrorHtml will, by default, set-message-redirect. That make this behavior neutral for users running defaults. For users that have customized this, it will be an improvement that all our error cases now respect it. - Provided a JSON representation of errors - Attach a random correlation identifier The last two were just nice-to-haves that were cheap to add once the code was in this state. Note that the use of MonadError requires a potentially "bad" orphan MonadUnliftIO instance for ExceptT, but I'd like to see that instance become a reality and think it needs some real-world experimentation to get there, so here I am.
This commit is contained in:
parent
16aad54338
commit
ab17f214eb
@ -35,11 +35,13 @@ library:
|
||||
- http-types >=0.8 && <0.13
|
||||
- memory
|
||||
- microlens
|
||||
- mtl
|
||||
- safe-exceptions
|
||||
- text >=0.7 && <2.0
|
||||
- uri-bytestring
|
||||
- yesod-auth >=1.6.0 && <1.7
|
||||
- yesod-core >=1.6.0 && <1.7
|
||||
- unliftio
|
||||
|
||||
executables:
|
||||
yesod-auth-oauth2-example:
|
||||
|
||||
12
src/UnliftIO/Except.hs
Normal file
12
src/UnliftIO/Except.hs
Normal file
@ -0,0 +1,12 @@
|
||||
{-# OPTIONS_GHC -Wno-orphans #-}
|
||||
|
||||
module UnliftIO.Except
|
||||
() where
|
||||
|
||||
import Control.Monad.Except
|
||||
import UnliftIO
|
||||
|
||||
instance (MonadUnliftIO m, Exception e) => MonadUnliftIO (ExceptT e m) where
|
||||
withRunInIO exceptToIO = ExceptT $ try $ do
|
||||
withRunInIO $ \runInIO ->
|
||||
exceptToIO (runInIO . (either throwIO pure <=< runExceptT))
|
||||
@ -1,36 +1,30 @@
|
||||
{-# LANGUAGE FlexibleContexts #-}
|
||||
{-# LANGUAGE LambdaCase #-}
|
||||
{-# LANGUAGE OverloadedStrings #-}
|
||||
{-# LANGUAGE RankNTypes #-}
|
||||
{-# LANGUAGE RecordWildCards #-}
|
||||
{-# LANGUAGE ScopedTypeVariables #-}
|
||||
{-# LANGUAGE TemplateHaskell #-}
|
||||
{-# LANGUAGE TypeApplications #-}
|
||||
{-# LANGUAGE TypeFamilies #-}
|
||||
|
||||
module Yesod.Auth.OAuth2.Dispatch
|
||||
( FetchToken
|
||||
, fetchAccessToken
|
||||
, fetchAccessToken2
|
||||
, FetchCreds
|
||||
, dispatchAuthRequest
|
||||
)
|
||||
where
|
||||
) where
|
||||
|
||||
import Control.Exception.Safe
|
||||
import Control.Monad (unless, (<=<))
|
||||
import Crypto.Random (getRandomBytes)
|
||||
import Data.ByteArray.Encoding (Base(Base64), convertToBase)
|
||||
import Data.ByteString (ByteString)
|
||||
import Control.Monad.Except
|
||||
import Data.Text (Text)
|
||||
import qualified Data.Text as T
|
||||
import Data.Text.Encoding (decodeUtf8, encodeUtf8)
|
||||
import Data.Text.Encoding (encodeUtf8)
|
||||
import Network.HTTP.Conduit (Manager)
|
||||
import Network.OAuth.OAuth2
|
||||
import Network.OAuth.OAuth2.TokenRequest (Errors)
|
||||
import UnliftIO.Exception
|
||||
import URI.ByteString.Extension
|
||||
import Yesod.Auth hiding (ServerError)
|
||||
import Yesod.Auth.OAuth2.DispatchError
|
||||
import Yesod.Auth.OAuth2.ErrorResponse
|
||||
import Yesod.Auth.OAuth2.Exception
|
||||
import Yesod.Auth.OAuth2.Random
|
||||
import Yesod.Core hiding (ErrorResponse)
|
||||
|
||||
-- | How to fetch an @'OAuth2Token'@
|
||||
@ -53,9 +47,9 @@ dispatchAuthRequest
|
||||
-> [Text] -- ^ Path pieces
|
||||
-> AuthHandler m TypedContent
|
||||
dispatchAuthRequest name oauth2 _ _ "GET" ["forward"] =
|
||||
dispatchForward name oauth2
|
||||
handleDispatchError $ dispatchForward name oauth2
|
||||
dispatchAuthRequest name oauth2 getToken getCreds "GET" ["callback"] =
|
||||
dispatchCallback name oauth2 getToken getCreds
|
||||
handleDispatchError $ dispatchCallback name oauth2 getToken getCreds
|
||||
dispatchAuthRequest _ _ _ _ _ _ = notFound
|
||||
|
||||
-- | Handle @GET \/forward@
|
||||
@ -63,7 +57,11 @@ dispatchAuthRequest _ _ _ _ _ _ = notFound
|
||||
-- 1. Set a random CSRF token in our session
|
||||
-- 2. Redirect to the Provider's authorization URL
|
||||
--
|
||||
dispatchForward :: Text -> OAuth2 -> AuthHandler m TypedContent
|
||||
dispatchForward
|
||||
:: (MonadError DispatchError m, MonadAuthHandler site m)
|
||||
=> Text
|
||||
-> OAuth2
|
||||
-> m TypedContent
|
||||
dispatchForward name oauth2 = do
|
||||
csrf <- setSessionCSRF $ tokenSessionKey name
|
||||
oauth2' <- withCallbackAndState name oauth2 csrf
|
||||
@ -76,75 +74,47 @@ dispatchForward name oauth2 = do
|
||||
-- 3. Use the AccessToken to construct a @'Creds'@ value for the Provider
|
||||
--
|
||||
dispatchCallback
|
||||
:: Text
|
||||
:: (MonadError DispatchError m, MonadAuthHandler site m)
|
||||
=> Text
|
||||
-> OAuth2
|
||||
-> FetchToken
|
||||
-> FetchCreds m
|
||||
-> AuthHandler m TypedContent
|
||||
-> FetchCreds site
|
||||
-> m TypedContent
|
||||
dispatchCallback name oauth2 getToken getCreds = do
|
||||
csrf <- verifySessionCSRF $ tokenSessionKey name
|
||||
onErrorResponse $ oauth2HandshakeError name
|
||||
onErrorResponse $ throwError . OAuth2HandshakeError
|
||||
code <- requireGetParam "code"
|
||||
manager <- authHttpManager
|
||||
oauth2' <- withCallbackAndState name oauth2 csrf
|
||||
token <- errLeft $ getToken manager oauth2' $ ExchangeToken code
|
||||
creds <- errLeft $ tryFetchCreds $ getCreds manager token
|
||||
token <-
|
||||
errLeft OAuth2ResultError $ getToken manager oauth2' $ ExchangeToken
|
||||
code
|
||||
creds <- errLeft id $ tryFetchCreds $ getCreds manager token
|
||||
setCredsRedirect creds
|
||||
where
|
||||
errLeft :: Show e => IO (Either e a) -> AuthHandler m a
|
||||
errLeft = either (unexpectedError name) pure <=< liftIO
|
||||
errLeft
|
||||
:: (MonadIO m, MonadError e m) => (e' -> e) -> IO (Either e' a) -> m a
|
||||
errLeft f = either (throwError . f) pure <=< liftIO
|
||||
|
||||
-- | Handle an OAuth2 @'ErrorResponse'@
|
||||
--
|
||||
-- These are things coming from the OAuth2 provider such an Invalid Grant or
|
||||
-- Invalid Scope and /may/ be user-actionable. We've coded them to have an
|
||||
-- @'erUserMessage'@ that we are comfortable displaying to the user as part of
|
||||
-- the redirect, just in case.
|
||||
--
|
||||
oauth2HandshakeError :: Text -> ErrorResponse -> AuthHandler m a
|
||||
oauth2HandshakeError name err = do
|
||||
$(logError) $ "Handshake failure in " <> name <> " plugin: " <> tshow err
|
||||
redirectMessage $ "OAuth2 handshake failure: " <> erUserMessage err
|
||||
|
||||
-- | Handle an unexpected error
|
||||
--
|
||||
-- This would be some unexpected exception while processing the callback.
|
||||
-- Therefore, the user should see an opaque message and the details go only to
|
||||
-- the server logs.
|
||||
--
|
||||
unexpectedError :: Show e => Text -> e -> AuthHandler m a
|
||||
unexpectedError name err = do
|
||||
$(logError) $ "Error in " <> name <> " OAuth2 plugin: " <> tshow err
|
||||
redirectMessage "Unexpected error logging in with OAuth2"
|
||||
|
||||
redirectMessage :: Text -> AuthHandler m a
|
||||
redirectMessage msg = do
|
||||
toParent <- getRouteToParent
|
||||
setMessage $ toHtml msg
|
||||
redirect $ toParent LoginR
|
||||
|
||||
tryFetchCreds :: IO a -> IO (Either SomeException a)
|
||||
tryFetchCreds :: IO a -> IO (Either DispatchError a)
|
||||
tryFetchCreds f =
|
||||
(Right <$> f)
|
||||
`catch` (\(ex :: IOException) -> pure $ Left $ toException ex)
|
||||
`catch` (\(ex :: YesodOAuth2Exception) -> pure $ Left $ toException ex)
|
||||
`catch` (pure . Left . FetchCredsIOException)
|
||||
`catch` (pure . Left . FetchCredsYesodOAuth2Exception)
|
||||
|
||||
withCallbackAndState :: Text -> OAuth2 -> Text -> AuthHandler m OAuth2
|
||||
withCallbackAndState
|
||||
:: (MonadError DispatchError m, MonadAuthHandler site m)
|
||||
=> Text
|
||||
-> OAuth2
|
||||
-> Text
|
||||
-> m OAuth2
|
||||
withCallbackAndState name oauth2 csrf = do
|
||||
let url = PluginR name ["callback"]
|
||||
render <- getParentUrlRender
|
||||
let callbackText = render url
|
||||
|
||||
callback <-
|
||||
maybe
|
||||
(liftIO
|
||||
$ throwString
|
||||
$ "Invalid callback URI: "
|
||||
<> T.unpack callbackText
|
||||
<> ". Not using an absolute Approot?"
|
||||
)
|
||||
pure
|
||||
$ fromText callbackText
|
||||
callback <- maybe (throwError $ InvalidCallbackUri callbackText) pure
|
||||
$ fromText callbackText
|
||||
|
||||
pure oauth2
|
||||
{ oauthCallback = Just callback
|
||||
@ -169,40 +139,28 @@ setSessionCSRF :: MonadHandler m => Text -> m Text
|
||||
setSessionCSRF sessionKey = do
|
||||
csrfToken <- liftIO randomToken
|
||||
csrfToken <$ setSession sessionKey csrfToken
|
||||
where
|
||||
randomToken =
|
||||
T.filter (/= '+')
|
||||
. decodeUtf8
|
||||
. convertToBase @ByteString Base64
|
||||
<$> getRandomBytes 64
|
||||
where randomToken = T.filter (/= '+') <$> randomText 64
|
||||
|
||||
-- | Verify the callback provided the same CSRF token as in our session
|
||||
verifySessionCSRF :: MonadHandler m => Text -> m Text
|
||||
verifySessionCSRF
|
||||
:: (MonadError DispatchError m, MonadHandler m) => Text -> m Text
|
||||
verifySessionCSRF sessionKey = do
|
||||
token <- requireGetParam "state"
|
||||
sessionToken <- lookupSession sessionKey
|
||||
deleteSession sessionKey
|
||||
|
||||
unless (sessionToken == Just token) $ do
|
||||
$(logError)
|
||||
$ "state token does not match. "
|
||||
<> "Param: "
|
||||
<> tshow token
|
||||
<> "State: "
|
||||
<> tshow sessionToken
|
||||
permissionDenied "Invalid OAuth2 state token"
|
||||
unless (sessionToken == Just token) $ throwError $ InvalidStateToken
|
||||
sessionToken
|
||||
token
|
||||
|
||||
return token
|
||||
pure token
|
||||
|
||||
requireGetParam :: MonadHandler m => Text -> m Text
|
||||
requireGetParam
|
||||
:: (MonadError DispatchError m, MonadHandler m) => Text -> m Text
|
||||
requireGetParam key = do
|
||||
m <- lookupGetParam key
|
||||
maybe errInvalidArgs return m
|
||||
where
|
||||
errInvalidArgs = invalidArgs ["The '" <> key <> "' parameter is required"]
|
||||
maybe err return m
|
||||
where err = throwError $ MissingParameter key
|
||||
|
||||
tokenSessionKey :: Text -> Text
|
||||
tokenSessionKey name = "_yesod_oauth2_" <> name
|
||||
|
||||
tshow :: Show a => a -> Text
|
||||
tshow = T.pack . show
|
||||
|
||||
78
src/Yesod/Auth/OAuth2/DispatchError.hs
Normal file
78
src/Yesod/Auth/OAuth2/DispatchError.hs
Normal file
@ -0,0 +1,78 @@
|
||||
{-# LANGUAGE DeriveAnyClass #-}
|
||||
{-# LANGUAGE DerivingStrategies #-}
|
||||
{-# LANGUAGE FlexibleContexts #-}
|
||||
{-# LANGUAGE LambdaCase #-}
|
||||
{-# LANGUAGE OverloadedStrings #-}
|
||||
{-# LANGUAGE ScopedTypeVariables #-}
|
||||
{-# LANGUAGE TemplateHaskell #-}
|
||||
{-# LANGUAGE TypeApplications #-}
|
||||
{-# LANGUAGE TypeFamilies #-}
|
||||
|
||||
module Yesod.Auth.OAuth2.DispatchError
|
||||
( DispatchError(..)
|
||||
, handleDispatchError
|
||||
) where
|
||||
|
||||
import Control.Monad.Except
|
||||
import Data.Text (Text, pack)
|
||||
import Network.OAuth.OAuth2
|
||||
import Network.OAuth.OAuth2.TokenRequest (Errors)
|
||||
import UnliftIO.Except ()
|
||||
import UnliftIO.Exception
|
||||
import Yesod.Auth hiding (ServerError)
|
||||
import Yesod.Auth.OAuth2.ErrorResponse
|
||||
import Yesod.Auth.OAuth2.Exception
|
||||
import Yesod.Auth.OAuth2.Random
|
||||
import Yesod.Core hiding (ErrorResponse)
|
||||
|
||||
data DispatchError
|
||||
= MissingParameter Text
|
||||
| InvalidStateToken (Maybe Text) Text
|
||||
| InvalidCallbackUri Text
|
||||
| OAuth2HandshakeError ErrorResponse
|
||||
| OAuth2ResultError (OAuth2Error Errors)
|
||||
| FetchCredsIOException IOException
|
||||
| FetchCredsYesodOAuth2Exception YesodOAuth2Exception
|
||||
deriving stock Show
|
||||
deriving anyclass Exception
|
||||
|
||||
-- | User-friendly message for any given 'DispatchError'
|
||||
--
|
||||
-- Most of these are opaque to the user. The exception details are present for
|
||||
-- the server logs.
|
||||
--
|
||||
dispatchErrorMessage :: DispatchError -> Text
|
||||
dispatchErrorMessage = \case
|
||||
MissingParameter name ->
|
||||
"Parameter '" <> name <> "' is required, but not present in the URL"
|
||||
InvalidStateToken{} -> "State token is invalid, please try again"
|
||||
InvalidCallbackUri{}
|
||||
-> "Callback URI was not valid, this server may be misconfigured (no approot)"
|
||||
OAuth2HandshakeError er -> "OAuth2 handshake failure: " <> erUserMessage er
|
||||
OAuth2ResultError{} -> "Login failed, please try again"
|
||||
FetchCredsIOException{} -> "Login failed, please try again"
|
||||
FetchCredsYesodOAuth2Exception{} -> "Login failed, please try again"
|
||||
|
||||
handleDispatchError
|
||||
:: MonadAuthHandler site m
|
||||
=> ExceptT DispatchError m TypedContent
|
||||
-> m TypedContent
|
||||
handleDispatchError f = do
|
||||
result <- runExceptT f
|
||||
either onDispatchError pure result
|
||||
|
||||
onDispatchError :: MonadAuthHandler site m => DispatchError -> m TypedContent
|
||||
onDispatchError err = do
|
||||
errorId <- liftIO $ randomText 16
|
||||
let suffix = " [errorId=" <> errorId <> "]"
|
||||
$(logError) $ pack (displayException err) <> suffix
|
||||
|
||||
let message = dispatchErrorMessage err <> suffix
|
||||
messageValue =
|
||||
object ["error" .= object ["id" .= errorId, "message" .= message]]
|
||||
|
||||
loginR <- ($ LoginR) <$> getRouteToParent
|
||||
|
||||
selectRep $ do
|
||||
provideRep @_ @Html $ onErrorHtml loginR message
|
||||
provideRep @_ @Value $ pure messageValue
|
||||
19
src/Yesod/Auth/OAuth2/Random.hs
Normal file
19
src/Yesod/Auth/OAuth2/Random.hs
Normal file
@ -0,0 +1,19 @@
|
||||
{-# LANGUAGE TypeApplications #-}
|
||||
|
||||
module Yesod.Auth.OAuth2.Random
|
||||
( randomText
|
||||
) where
|
||||
|
||||
import Crypto.Random (MonadRandom, getRandomBytes)
|
||||
import Data.ByteArray.Encoding (Base(Base64), convertToBase)
|
||||
import Data.ByteString (ByteString)
|
||||
import Data.Text (Text)
|
||||
import Data.Text.Encoding (decodeUtf8)
|
||||
|
||||
randomText
|
||||
:: MonadRandom m
|
||||
=> Int
|
||||
-- ^ Size in Bytes (note necessarily characters)
|
||||
-> m Text
|
||||
randomText size =
|
||||
decodeUtf8 . convertToBase @ByteString Base64 <$> getRandomBytes size
|
||||
Loading…
Reference in New Issue
Block a user