Consolidate all errors, use onErrorHtml

Prior to this commit, some errors would be thrown (missing parameter,
invalid state, incorrect approot) while others would be handled via the
set-message-redirect approach (handshake failure, fetch-token failure,
etc).

This commit consolidates all of these cases into a single DispatchError
type, and then uses MonadError (concretely ExceptT) to capture them all
and handle them in one place ourselves.

It then updates that handling to:

- Use onErrorHtml

  onErrorHtml will, by default, set-message-redirect. That make this
  behavior neutral for users running defaults. For users that have
  customized this, it will be an improvement that all our error cases
  now respect it.

- Provided a JSON representation of errors
- Attach a random correlation identifier

The last two were just nice-to-haves that were cheap to add once the
code was in this state.

Note that the use of MonadError requires a potentially "bad" orphan
MonadUnliftIO instance for ExceptT, but I'd like to see that instance
become a reality and think it needs some real-world experimentation to
get there, so here I am.
This commit is contained in:
patrick brisbin 2021-02-26 14:07:23 -05:00
parent 16aad54338
commit ab17f214eb
5 changed files with 159 additions and 90 deletions

View File

@ -35,11 +35,13 @@ library:
- http-types >=0.8 && <0.13
- memory
- microlens
- mtl
- safe-exceptions
- text >=0.7 && <2.0
- uri-bytestring
- yesod-auth >=1.6.0 && <1.7
- yesod-core >=1.6.0 && <1.7
- unliftio
executables:
yesod-auth-oauth2-example:

12
src/UnliftIO/Except.hs Normal file
View File

@ -0,0 +1,12 @@
{-# OPTIONS_GHC -Wno-orphans #-}
module UnliftIO.Except
() where
import Control.Monad.Except
import UnliftIO
instance (MonadUnliftIO m, Exception e) => MonadUnliftIO (ExceptT e m) where
withRunInIO exceptToIO = ExceptT $ try $ do
withRunInIO $ \runInIO ->
exceptToIO (runInIO . (either throwIO pure <=< runExceptT))

View File

@ -1,36 +1,30 @@
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
module Yesod.Auth.OAuth2.Dispatch
( FetchToken
, fetchAccessToken
, fetchAccessToken2
, FetchCreds
, dispatchAuthRequest
)
where
) where
import Control.Exception.Safe
import Control.Monad (unless, (<=<))
import Crypto.Random (getRandomBytes)
import Data.ByteArray.Encoding (Base(Base64), convertToBase)
import Data.ByteString (ByteString)
import Control.Monad.Except
import Data.Text (Text)
import qualified Data.Text as T
import Data.Text.Encoding (decodeUtf8, encodeUtf8)
import Data.Text.Encoding (encodeUtf8)
import Network.HTTP.Conduit (Manager)
import Network.OAuth.OAuth2
import Network.OAuth.OAuth2.TokenRequest (Errors)
import UnliftIO.Exception
import URI.ByteString.Extension
import Yesod.Auth hiding (ServerError)
import Yesod.Auth.OAuth2.DispatchError
import Yesod.Auth.OAuth2.ErrorResponse
import Yesod.Auth.OAuth2.Exception
import Yesod.Auth.OAuth2.Random
import Yesod.Core hiding (ErrorResponse)
-- | How to fetch an @'OAuth2Token'@
@ -53,9 +47,9 @@ dispatchAuthRequest
-> [Text] -- ^ Path pieces
-> AuthHandler m TypedContent
dispatchAuthRequest name oauth2 _ _ "GET" ["forward"] =
dispatchForward name oauth2
handleDispatchError $ dispatchForward name oauth2
dispatchAuthRequest name oauth2 getToken getCreds "GET" ["callback"] =
dispatchCallback name oauth2 getToken getCreds
handleDispatchError $ dispatchCallback name oauth2 getToken getCreds
dispatchAuthRequest _ _ _ _ _ _ = notFound
-- | Handle @GET \/forward@
@ -63,7 +57,11 @@ dispatchAuthRequest _ _ _ _ _ _ = notFound
-- 1. Set a random CSRF token in our session
-- 2. Redirect to the Provider's authorization URL
--
dispatchForward :: Text -> OAuth2 -> AuthHandler m TypedContent
dispatchForward
:: (MonadError DispatchError m, MonadAuthHandler site m)
=> Text
-> OAuth2
-> m TypedContent
dispatchForward name oauth2 = do
csrf <- setSessionCSRF $ tokenSessionKey name
oauth2' <- withCallbackAndState name oauth2 csrf
@ -76,75 +74,47 @@ dispatchForward name oauth2 = do
-- 3. Use the AccessToken to construct a @'Creds'@ value for the Provider
--
dispatchCallback
:: Text
:: (MonadError DispatchError m, MonadAuthHandler site m)
=> Text
-> OAuth2
-> FetchToken
-> FetchCreds m
-> AuthHandler m TypedContent
-> FetchCreds site
-> m TypedContent
dispatchCallback name oauth2 getToken getCreds = do
csrf <- verifySessionCSRF $ tokenSessionKey name
onErrorResponse $ oauth2HandshakeError name
onErrorResponse $ throwError . OAuth2HandshakeError
code <- requireGetParam "code"
manager <- authHttpManager
oauth2' <- withCallbackAndState name oauth2 csrf
token <- errLeft $ getToken manager oauth2' $ ExchangeToken code
creds <- errLeft $ tryFetchCreds $ getCreds manager token
token <-
errLeft OAuth2ResultError $ getToken manager oauth2' $ ExchangeToken
code
creds <- errLeft id $ tryFetchCreds $ getCreds manager token
setCredsRedirect creds
where
errLeft :: Show e => IO (Either e a) -> AuthHandler m a
errLeft = either (unexpectedError name) pure <=< liftIO
errLeft
:: (MonadIO m, MonadError e m) => (e' -> e) -> IO (Either e' a) -> m a
errLeft f = either (throwError . f) pure <=< liftIO
-- | Handle an OAuth2 @'ErrorResponse'@
--
-- These are things coming from the OAuth2 provider such an Invalid Grant or
-- Invalid Scope and /may/ be user-actionable. We've coded them to have an
-- @'erUserMessage'@ that we are comfortable displaying to the user as part of
-- the redirect, just in case.
--
oauth2HandshakeError :: Text -> ErrorResponse -> AuthHandler m a
oauth2HandshakeError name err = do
$(logError) $ "Handshake failure in " <> name <> " plugin: " <> tshow err
redirectMessage $ "OAuth2 handshake failure: " <> erUserMessage err
-- | Handle an unexpected error
--
-- This would be some unexpected exception while processing the callback.
-- Therefore, the user should see an opaque message and the details go only to
-- the server logs.
--
unexpectedError :: Show e => Text -> e -> AuthHandler m a
unexpectedError name err = do
$(logError) $ "Error in " <> name <> " OAuth2 plugin: " <> tshow err
redirectMessage "Unexpected error logging in with OAuth2"
redirectMessage :: Text -> AuthHandler m a
redirectMessage msg = do
toParent <- getRouteToParent
setMessage $ toHtml msg
redirect $ toParent LoginR
tryFetchCreds :: IO a -> IO (Either SomeException a)
tryFetchCreds :: IO a -> IO (Either DispatchError a)
tryFetchCreds f =
(Right <$> f)
`catch` (\(ex :: IOException) -> pure $ Left $ toException ex)
`catch` (\(ex :: YesodOAuth2Exception) -> pure $ Left $ toException ex)
`catch` (pure . Left . FetchCredsIOException)
`catch` (pure . Left . FetchCredsYesodOAuth2Exception)
withCallbackAndState :: Text -> OAuth2 -> Text -> AuthHandler m OAuth2
withCallbackAndState
:: (MonadError DispatchError m, MonadAuthHandler site m)
=> Text
-> OAuth2
-> Text
-> m OAuth2
withCallbackAndState name oauth2 csrf = do
let url = PluginR name ["callback"]
render <- getParentUrlRender
let callbackText = render url
callback <-
maybe
(liftIO
$ throwString
$ "Invalid callback URI: "
<> T.unpack callbackText
<> ". Not using an absolute Approot?"
)
pure
$ fromText callbackText
callback <- maybe (throwError $ InvalidCallbackUri callbackText) pure
$ fromText callbackText
pure oauth2
{ oauthCallback = Just callback
@ -169,40 +139,28 @@ setSessionCSRF :: MonadHandler m => Text -> m Text
setSessionCSRF sessionKey = do
csrfToken <- liftIO randomToken
csrfToken <$ setSession sessionKey csrfToken
where
randomToken =
T.filter (/= '+')
. decodeUtf8
. convertToBase @ByteString Base64
<$> getRandomBytes 64
where randomToken = T.filter (/= '+') <$> randomText 64
-- | Verify the callback provided the same CSRF token as in our session
verifySessionCSRF :: MonadHandler m => Text -> m Text
verifySessionCSRF
:: (MonadError DispatchError m, MonadHandler m) => Text -> m Text
verifySessionCSRF sessionKey = do
token <- requireGetParam "state"
sessionToken <- lookupSession sessionKey
deleteSession sessionKey
unless (sessionToken == Just token) $ do
$(logError)
$ "state token does not match. "
<> "Param: "
<> tshow token
<> "State: "
<> tshow sessionToken
permissionDenied "Invalid OAuth2 state token"
unless (sessionToken == Just token) $ throwError $ InvalidStateToken
sessionToken
token
return token
pure token
requireGetParam :: MonadHandler m => Text -> m Text
requireGetParam
:: (MonadError DispatchError m, MonadHandler m) => Text -> m Text
requireGetParam key = do
m <- lookupGetParam key
maybe errInvalidArgs return m
where
errInvalidArgs = invalidArgs ["The '" <> key <> "' parameter is required"]
maybe err return m
where err = throwError $ MissingParameter key
tokenSessionKey :: Text -> Text
tokenSessionKey name = "_yesod_oauth2_" <> name
tshow :: Show a => a -> Text
tshow = T.pack . show

View File

@ -0,0 +1,78 @@
{-# LANGUAGE DeriveAnyClass #-}
{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
module Yesod.Auth.OAuth2.DispatchError
( DispatchError(..)
, handleDispatchError
) where
import Control.Monad.Except
import Data.Text (Text, pack)
import Network.OAuth.OAuth2
import Network.OAuth.OAuth2.TokenRequest (Errors)
import UnliftIO.Except ()
import UnliftIO.Exception
import Yesod.Auth hiding (ServerError)
import Yesod.Auth.OAuth2.ErrorResponse
import Yesod.Auth.OAuth2.Exception
import Yesod.Auth.OAuth2.Random
import Yesod.Core hiding (ErrorResponse)
data DispatchError
= MissingParameter Text
| InvalidStateToken (Maybe Text) Text
| InvalidCallbackUri Text
| OAuth2HandshakeError ErrorResponse
| OAuth2ResultError (OAuth2Error Errors)
| FetchCredsIOException IOException
| FetchCredsYesodOAuth2Exception YesodOAuth2Exception
deriving stock Show
deriving anyclass Exception
-- | User-friendly message for any given 'DispatchError'
--
-- Most of these are opaque to the user. The exception details are present for
-- the server logs.
--
dispatchErrorMessage :: DispatchError -> Text
dispatchErrorMessage = \case
MissingParameter name ->
"Parameter '" <> name <> "' is required, but not present in the URL"
InvalidStateToken{} -> "State token is invalid, please try again"
InvalidCallbackUri{}
-> "Callback URI was not valid, this server may be misconfigured (no approot)"
OAuth2HandshakeError er -> "OAuth2 handshake failure: " <> erUserMessage er
OAuth2ResultError{} -> "Login failed, please try again"
FetchCredsIOException{} -> "Login failed, please try again"
FetchCredsYesodOAuth2Exception{} -> "Login failed, please try again"
handleDispatchError
:: MonadAuthHandler site m
=> ExceptT DispatchError m TypedContent
-> m TypedContent
handleDispatchError f = do
result <- runExceptT f
either onDispatchError pure result
onDispatchError :: MonadAuthHandler site m => DispatchError -> m TypedContent
onDispatchError err = do
errorId <- liftIO $ randomText 16
let suffix = " [errorId=" <> errorId <> "]"
$(logError) $ pack (displayException err) <> suffix
let message = dispatchErrorMessage err <> suffix
messageValue =
object ["error" .= object ["id" .= errorId, "message" .= message]]
loginR <- ($ LoginR) <$> getRouteToParent
selectRep $ do
provideRep @_ @Html $ onErrorHtml loginR message
provideRep @_ @Value $ pure messageValue

View File

@ -0,0 +1,19 @@
{-# LANGUAGE TypeApplications #-}
module Yesod.Auth.OAuth2.Random
( randomText
) where
import Crypto.Random (MonadRandom, getRandomBytes)
import Data.ByteArray.Encoding (Base(Base64), convertToBase)
import Data.ByteString (ByteString)
import Data.Text (Text)
import Data.Text.Encoding (decodeUtf8)
randomText
:: MonadRandom m
=> Int
-- ^ Size in Bytes (note necessarily characters)
-> m Text
randomText size =
decodeUtf8 . convertToBase @ByteString Base64 <$> getRandomBytes size