Redirect on OAuth2 errors, not permissionDenied

This commit is contained in:
patrick brisbin 2018-03-27 17:19:51 -04:00
parent 66317cae11
commit e025854e52
No known key found for this signature in database
GPG Key ID: 4243EA839B9CC425
2 changed files with 41 additions and 11 deletions

View File

@ -1,4 +1,5 @@
{-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RankNTypes #-} {-# LANGUAGE RankNTypes #-}
{-# LANGUAGE RecordWildCards #-} {-# LANGUAGE RecordWildCards #-}
@ -19,9 +20,9 @@ import Network.HTTP.Conduit (Manager)
import Network.OAuth.OAuth2 import Network.OAuth.OAuth2
import System.Random (newStdGen, randomRs) import System.Random (newStdGen, randomRs)
import URI.ByteString.Extension import URI.ByteString.Extension
import Yesod.Auth import Yesod.Auth hiding (ServerError)
import Yesod.Auth.OAuth2.ErrorResponse (onErrorResponse) import Yesod.Auth.OAuth2.ErrorResponse
import Yesod.Core import Yesod.Core hiding (ErrorResponse)
-- | How to take an @'OAuth2Token'@ and retrieve user credentials -- | How to take an @'OAuth2Token'@ and retrieve user credentials
type FetchCreds m = Manager -> OAuth2Token -> IO (Creds m) type FetchCreds m = Manager -> OAuth2Token -> IO (Creds m)
@ -64,18 +65,23 @@ dispatchCallback name oauth2 getCreds = do
code <- requireGetParam "code" code <- requireGetParam "code"
manager <- authHttpManager manager <- authHttpManager
oauth2' <- withCallbackAndState name oauth2 csrf oauth2' <- withCallbackAndState name oauth2 csrf
token <- denyLeft $ fetchAccessToken manager oauth2' $ ExchangeToken code token <- errLeft $ fetchAccessToken manager oauth2' $ ExchangeToken code
creds <- denyLeft $ tryIO $ getCreds manager token creds <- errLeft $ tryIO $ getCreds manager token
setCredsRedirect creds setCredsRedirect creds
where where
-- On a Left result, log it and return an opaque permission-denied errLeft :: Show e => IO (Either e a) -> AuthHandler m a
denyLeft :: (MonadHandler m, MonadLogger m, Show e) => IO (Either e a) -> m a errLeft = either (errInvalidOAuth . unknownError . tshow) pure <=< liftIO
denyLeft = either errInvalidOAuth pure <=< liftIO
errInvalidOAuth :: (MonadHandler m, MonadLogger m, Show e) => e -> m a errInvalidOAuth :: ErrorResponse -> AuthHandler m a
errInvalidOAuth err = do errInvalidOAuth err = do
$(logError) $ T.pack $ "OAuth2 error: " <> show err $(logError) $ "OAuth2 error (" <> name <> "): " <> tshow err
permissionDenied "Invalid OAuth2 authentication attempt" redirectMessage $ "Unable to log in with OAuth2: " <> erUserMessage err
redirectMessage :: Text -> AuthHandler m a
redirectMessage msg = do
toParent <- getRouteToParent
setMessage $ toHtml msg
redirect $ toParent LoginR
withCallbackAndState :: Text -> OAuth2 -> Text -> AuthHandler m OAuth2 withCallbackAndState :: Text -> OAuth2 -> Text -> AuthHandler m OAuth2
withCallbackAndState name oauth2 csrf = do withCallbackAndState name oauth2 csrf = do
@ -132,3 +138,6 @@ requireGetParam key = do
tokenSessionKey :: Text -> Text tokenSessionKey :: Text -> Text
tokenSessionKey name = "_yesod_oauth2_" <> name tokenSessionKey name = "_yesod_oauth2_" <> name
tshow :: Show a => a -> Text
tshow = T.pack . show

View File

@ -5,8 +5,10 @@
-- --
module Yesod.Auth.OAuth2.ErrorResponse module Yesod.Auth.OAuth2.ErrorResponse
( ErrorResponse(..) ( ErrorResponse(..)
, erUserMessage
, ErrorName(..) , ErrorName(..)
, onErrorResponse , onErrorResponse
, unknownError
) where ) where
import Data.Foldable (traverse_) import Data.Foldable (traverse_)
@ -32,6 +34,25 @@ data ErrorResponse = ErrorResponse
} }
deriving Show deriving Show
-- | Textual value suitable for display to a User
erUserMessage :: ErrorResponse -> Text
erUserMessage err = case erName err of
InvalidRequest -> "Invalid request"
UnauthorizedClient -> "Unauthorized client"
AccessDenied -> "Access denied"
UnsupportedResponseType -> "Unsupported response type"
InvalidScope -> "Invalid scope"
ServerError -> "Server error"
TemporarilyUnavailable -> "Temporarily unavailable"
Unknown _ -> "Unknown error"
unknownError :: Text -> ErrorResponse
unknownError x = ErrorResponse
{ erName = Unknown x
, erDescription = Nothing
, erURI = Nothing
}
-- | Check query parameters for an error, if found run the given action -- | Check query parameters for an error, if found run the given action
-- --
-- The action is expected to use a short-circuit response function like -- The action is expected to use a short-circuit response function like