Consolidate all errors, use onErrorHtml

Prior to this commit, some errors would be thrown (missing parameter,
invalid state, incorrect approot) while others would be handled via the
set-message-redirect approach (handshake failure, fetch-token failure,
etc).

This commit consolidates all of these cases into a single DispatchError
type, and then uses MonadError (concretely ExceptT) to capture them all
and handle them in one place ourselves.

It then updates that handling to:

- Use onErrorHtml

  onErrorHtml will, by default, set-message-redirect. That make this
  behavior neutral for users running defaults. For users that have
  customized this, it will be an improvement that all our error cases
  now respect it.

- Provided a JSON representation of errors
- Attach a random correlation identifier

The last two were just nice-to-haves that were cheap to add once the
code was in this state.

Note that the use of MonadError requires a potentially "bad" orphan
MonadUnliftIO instance for ExceptT, but I'd like to see that instance
become a reality and think it needs some real-world experimentation to
get there, so here I am.
This commit is contained in:
patrick brisbin 2021-02-26 14:07:23 -05:00
parent cfcd8c5210
commit 8b3908ec91
No known key found for this signature in database
GPG Key ID: 20299C6982D938FB
5 changed files with 159 additions and 90 deletions

View File

@ -32,11 +32,13 @@ library:
- http-types >=0.8 && <0.13 - http-types >=0.8 && <0.13
- memory - memory
- microlens - microlens
- mtl
- safe-exceptions - safe-exceptions
- text >=0.7 && <2.0 - text >=0.7 && <2.0
- uri-bytestring - uri-bytestring
- yesod-auth >=1.6.0 && <1.7 - yesod-auth >=1.6.0 && <1.7
- yesod-core >=1.6.0 && <1.7 - yesod-core >=1.6.0 && <1.7
- unliftio
executables: executables:
yesod-auth-oauth2-example: yesod-auth-oauth2-example:

12
src/UnliftIO/Except.hs Normal file
View File

@ -0,0 +1,12 @@
{-# OPTIONS_GHC -Wno-orphans #-}
module UnliftIO.Except
() where
import Control.Monad.Except
import UnliftIO
instance (MonadUnliftIO m, Exception e) => MonadUnliftIO (ExceptT e m) where
withRunInIO exceptToIO = ExceptT $ try $ do
withRunInIO $ \runInIO ->
exceptToIO (runInIO . (either throwIO pure <=< runExceptT))

View File

@ -1,36 +1,30 @@
{-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RankNTypes #-} {-# LANGUAGE RankNTypes #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeFamilies #-}
module Yesod.Auth.OAuth2.Dispatch module Yesod.Auth.OAuth2.Dispatch
( FetchToken ( FetchToken
, fetchAccessToken , fetchAccessToken
, fetchAccessToken2 , fetchAccessToken2
, FetchCreds , FetchCreds
, dispatchAuthRequest , dispatchAuthRequest
) ) where
where
import Control.Exception.Safe import Control.Monad.Except
import Control.Monad (unless, (<=<))
import Crypto.Random (getRandomBytes)
import Data.ByteArray.Encoding (Base(Base64), convertToBase)
import Data.ByteString (ByteString)
import Data.Text (Text) import Data.Text (Text)
import qualified Data.Text as T import qualified Data.Text as T
import Data.Text.Encoding (decodeUtf8, encodeUtf8) import Data.Text.Encoding (encodeUtf8)
import Network.HTTP.Conduit (Manager) import Network.HTTP.Conduit (Manager)
import Network.OAuth.OAuth2 import Network.OAuth.OAuth2
import Network.OAuth.OAuth2.TokenRequest (Errors) import Network.OAuth.OAuth2.TokenRequest (Errors)
import UnliftIO.Exception
import URI.ByteString.Extension import URI.ByteString.Extension
import Yesod.Auth hiding (ServerError) import Yesod.Auth hiding (ServerError)
import Yesod.Auth.OAuth2.DispatchError
import Yesod.Auth.OAuth2.ErrorResponse import Yesod.Auth.OAuth2.ErrorResponse
import Yesod.Auth.OAuth2.Exception import Yesod.Auth.OAuth2.Random
import Yesod.Core hiding (ErrorResponse) import Yesod.Core hiding (ErrorResponse)
-- | How to fetch an @'OAuth2Token'@ -- | How to fetch an @'OAuth2Token'@
@ -53,9 +47,9 @@ dispatchAuthRequest
-> [Text] -- ^ Path pieces -> [Text] -- ^ Path pieces
-> AuthHandler m TypedContent -> AuthHandler m TypedContent
dispatchAuthRequest name oauth2 _ _ "GET" ["forward"] = dispatchAuthRequest name oauth2 _ _ "GET" ["forward"] =
dispatchForward name oauth2 handleDispatchError $ dispatchForward name oauth2
dispatchAuthRequest name oauth2 getToken getCreds "GET" ["callback"] = dispatchAuthRequest name oauth2 getToken getCreds "GET" ["callback"] =
dispatchCallback name oauth2 getToken getCreds handleDispatchError $ dispatchCallback name oauth2 getToken getCreds
dispatchAuthRequest _ _ _ _ _ _ = notFound dispatchAuthRequest _ _ _ _ _ _ = notFound
-- | Handle @GET \/forward@ -- | Handle @GET \/forward@
@ -63,7 +57,11 @@ dispatchAuthRequest _ _ _ _ _ _ = notFound
-- 1. Set a random CSRF token in our session -- 1. Set a random CSRF token in our session
-- 2. Redirect to the Provider's authorization URL -- 2. Redirect to the Provider's authorization URL
-- --
dispatchForward :: Text -> OAuth2 -> AuthHandler m TypedContent dispatchForward
:: (MonadError DispatchError m, MonadAuthHandler site m)
=> Text
-> OAuth2
-> m TypedContent
dispatchForward name oauth2 = do dispatchForward name oauth2 = do
csrf <- setSessionCSRF $ tokenSessionKey name csrf <- setSessionCSRF $ tokenSessionKey name
oauth2' <- withCallbackAndState name oauth2 csrf oauth2' <- withCallbackAndState name oauth2 csrf
@ -76,75 +74,47 @@ dispatchForward name oauth2 = do
-- 3. Use the AccessToken to construct a @'Creds'@ value for the Provider -- 3. Use the AccessToken to construct a @'Creds'@ value for the Provider
-- --
dispatchCallback dispatchCallback
:: Text :: (MonadError DispatchError m, MonadAuthHandler site m)
=> Text
-> OAuth2 -> OAuth2
-> FetchToken -> FetchToken
-> FetchCreds m -> FetchCreds site
-> AuthHandler m TypedContent -> m TypedContent
dispatchCallback name oauth2 getToken getCreds = do dispatchCallback name oauth2 getToken getCreds = do
csrf <- verifySessionCSRF $ tokenSessionKey name csrf <- verifySessionCSRF $ tokenSessionKey name
onErrorResponse $ oauth2HandshakeError name onErrorResponse $ throwError . OAuth2HandshakeError
code <- requireGetParam "code" code <- requireGetParam "code"
manager <- authHttpManager manager <- authHttpManager
oauth2' <- withCallbackAndState name oauth2 csrf oauth2' <- withCallbackAndState name oauth2 csrf
token <- errLeft $ getToken manager oauth2' $ ExchangeToken code token <-
creds <- errLeft $ tryFetchCreds $ getCreds manager token errLeft OAuth2ResultError $ getToken manager oauth2' $ ExchangeToken
code
creds <- errLeft id $ tryFetchCreds $ getCreds manager token
setCredsRedirect creds setCredsRedirect creds
where where
errLeft :: Show e => IO (Either e a) -> AuthHandler m a errLeft
errLeft = either (unexpectedError name) pure <=< liftIO :: (MonadIO m, MonadError e m) => (e' -> e) -> IO (Either e' a) -> m a
errLeft f = either (throwError . f) pure <=< liftIO
-- | Handle an OAuth2 @'ErrorResponse'@ tryFetchCreds :: IO a -> IO (Either DispatchError a)
--
-- These are things coming from the OAuth2 provider such an Invalid Grant or
-- Invalid Scope and /may/ be user-actionable. We've coded them to have an
-- @'erUserMessage'@ that we are comfortable displaying to the user as part of
-- the redirect, just in case.
--
oauth2HandshakeError :: Text -> ErrorResponse -> AuthHandler m a
oauth2HandshakeError name err = do
$(logError) $ "Handshake failure in " <> name <> " plugin: " <> tshow err
redirectMessage $ "OAuth2 handshake failure: " <> erUserMessage err
-- | Handle an unexpected error
--
-- This would be some unexpected exception while processing the callback.
-- Therefore, the user should see an opaque message and the details go only to
-- the server logs.
--
unexpectedError :: Show e => Text -> e -> AuthHandler m a
unexpectedError name err = do
$(logError) $ "Error in " <> name <> " OAuth2 plugin: " <> tshow err
redirectMessage "Unexpected error logging in with OAuth2"
redirectMessage :: Text -> AuthHandler m a
redirectMessage msg = do
toParent <- getRouteToParent
setMessage $ toHtml msg
redirect $ toParent LoginR
tryFetchCreds :: IO a -> IO (Either SomeException a)
tryFetchCreds f = tryFetchCreds f =
(Right <$> f) (Right <$> f)
`catch` (\(ex :: IOException) -> pure $ Left $ toException ex) `catch` (pure . Left . FetchCredsIOException)
`catch` (\(ex :: YesodOAuth2Exception) -> pure $ Left $ toException ex) `catch` (pure . Left . FetchCredsYesodOAuth2Exception)
withCallbackAndState :: Text -> OAuth2 -> Text -> AuthHandler m OAuth2 withCallbackAndState
:: (MonadError DispatchError m, MonadAuthHandler site m)
=> Text
-> OAuth2
-> Text
-> m OAuth2
withCallbackAndState name oauth2 csrf = do withCallbackAndState name oauth2 csrf = do
let url = PluginR name ["callback"] let url = PluginR name ["callback"]
render <- getParentUrlRender render <- getParentUrlRender
let callbackText = render url let callbackText = render url
callback <- callback <- maybe (throwError $ InvalidCallbackUri callbackText) pure
maybe $ fromText callbackText
(liftIO
$ throwString
$ "Invalid callback URI: "
<> T.unpack callbackText
<> ". Not using an absolute Approot?"
)
pure
$ fromText callbackText
pure oauth2 pure oauth2
{ oauthCallback = Just callback { oauthCallback = Just callback
@ -169,40 +139,28 @@ setSessionCSRF :: MonadHandler m => Text -> m Text
setSessionCSRF sessionKey = do setSessionCSRF sessionKey = do
csrfToken <- liftIO randomToken csrfToken <- liftIO randomToken
csrfToken <$ setSession sessionKey csrfToken csrfToken <$ setSession sessionKey csrfToken
where where randomToken = T.filter (/= '+') <$> randomText 64
randomToken =
T.filter (/= '+')
. decodeUtf8
. convertToBase @ByteString Base64
<$> getRandomBytes 64
-- | Verify the callback provided the same CSRF token as in our session -- | Verify the callback provided the same CSRF token as in our session
verifySessionCSRF :: MonadHandler m => Text -> m Text verifySessionCSRF
:: (MonadError DispatchError m, MonadHandler m) => Text -> m Text
verifySessionCSRF sessionKey = do verifySessionCSRF sessionKey = do
token <- requireGetParam "state" token <- requireGetParam "state"
sessionToken <- lookupSession sessionKey sessionToken <- lookupSession sessionKey
deleteSession sessionKey deleteSession sessionKey
unless (sessionToken == Just token) $ do unless (sessionToken == Just token) $ throwError $ InvalidStateToken
$(logError) sessionToken
$ "state token does not match. " token
<> "Param: "
<> tshow token
<> "State: "
<> tshow sessionToken
permissionDenied "Invalid OAuth2 state token"
return token pure token
requireGetParam :: MonadHandler m => Text -> m Text requireGetParam
:: (MonadError DispatchError m, MonadHandler m) => Text -> m Text
requireGetParam key = do requireGetParam key = do
m <- lookupGetParam key m <- lookupGetParam key
maybe errInvalidArgs return m maybe err return m
where where err = throwError $ MissingParameter key
errInvalidArgs = invalidArgs ["The '" <> key <> "' parameter is required"]
tokenSessionKey :: Text -> Text tokenSessionKey :: Text -> Text
tokenSessionKey name = "_yesod_oauth2_" <> name tokenSessionKey name = "_yesod_oauth2_" <> name
tshow :: Show a => a -> Text
tshow = T.pack . show

View File

@ -0,0 +1,78 @@
{-# LANGUAGE DeriveAnyClass #-}
{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
module Yesod.Auth.OAuth2.DispatchError
( DispatchError(..)
, handleDispatchError
) where
import Control.Monad.Except
import Data.Text (Text, pack)
import Network.OAuth.OAuth2
import Network.OAuth.OAuth2.TokenRequest (Errors)
import UnliftIO.Except ()
import UnliftIO.Exception
import Yesod.Auth hiding (ServerError)
import Yesod.Auth.OAuth2.ErrorResponse
import Yesod.Auth.OAuth2.Exception
import Yesod.Auth.OAuth2.Random
import Yesod.Core hiding (ErrorResponse)
data DispatchError
= MissingParameter Text
| InvalidStateToken (Maybe Text) Text
| InvalidCallbackUri Text
| OAuth2HandshakeError ErrorResponse
| OAuth2ResultError (OAuth2Error Errors)
| FetchCredsIOException IOException
| FetchCredsYesodOAuth2Exception YesodOAuth2Exception
deriving stock Show
deriving anyclass Exception
-- | User-friendly message for any given 'DispatchError'
--
-- Most of these are opaque to the user. The exception details are present for
-- the server logs.
--
dispatchErrorMessage :: DispatchError -> Text
dispatchErrorMessage = \case
MissingParameter name ->
"Parameter '" <> name <> "' is required, but not present in the URL"
InvalidStateToken{} -> "State token is invalid, please try again"
InvalidCallbackUri{}
-> "Callback URI was not valid, this server may be misconfigured (no approot)"
OAuth2HandshakeError er -> "OAuth2 handshake failure: " <> erUserMessage er
OAuth2ResultError{} -> "Login failed, please try again"
FetchCredsIOException{} -> "Login failed, please try again"
FetchCredsYesodOAuth2Exception{} -> "Login failed, please try again"
handleDispatchError
:: MonadAuthHandler site m
=> ExceptT DispatchError m TypedContent
-> m TypedContent
handleDispatchError f = do
result <- runExceptT f
either onDispatchError pure result
onDispatchError :: MonadAuthHandler site m => DispatchError -> m TypedContent
onDispatchError err = do
errorId <- liftIO $ randomText 16
let suffix = " [errorId=" <> errorId <> "]"
$(logError) $ pack (displayException err) <> suffix
let message = dispatchErrorMessage err <> suffix
messageValue =
object ["error" .= object ["id" .= errorId, "message" .= message]]
loginR <- ($ LoginR) <$> getRouteToParent
selectRep $ do
provideRep @_ @Html $ onErrorHtml loginR message
provideRep @_ @Value $ pure messageValue

View File

@ -0,0 +1,19 @@
{-# LANGUAGE TypeApplications #-}
module Yesod.Auth.OAuth2.Random
( randomText
) where
import Crypto.Random (MonadRandom, getRandomBytes)
import Data.ByteArray.Encoding (Base(Base64), convertToBase)
import Data.ByteString (ByteString)
import Data.Text (Text)
import Data.Text.Encoding (decodeUtf8)
randomText
:: MonadRandom m
=> Int
-- ^ Size in Bytes (note necessarily characters)
-> m Text
randomText size =
decodeUtf8 . convertToBase @ByteString Base64 <$> getRandomBytes size