diff --git a/src/Yesod/Auth/OAuth2/Dispatch.hs b/src/Yesod/Auth/OAuth2/Dispatch.hs index b1565f1..c044a8c 100644 --- a/src/Yesod/Auth/OAuth2/Dispatch.hs +++ b/src/Yesod/Auth/OAuth2/Dispatch.hs @@ -19,6 +19,7 @@ import Network.OAuth.OAuth2 import System.Random (newStdGen, randomRs) import URI.ByteString.Extension import Yesod.Auth +import Yesod.Auth.OAuth2.ErrorResponse (onErrorResponse) import Yesod.Core -- | How to take an @'OAuth2Token'@ and retrieve user credentials @@ -56,6 +57,7 @@ dispatchForward name oauth2 = do dispatchCallback :: Text -> OAuth2 -> FetchCreds m -> AuthHandler m TypedContent dispatchCallback name oauth2 getCreds = do csrf <- verifySessionCSRF $ tokenSessionKey name + onErrorResponse errInvalidOAuth code <- requireGetParam "code" manager <- lift $ getsYesod authHttpManager oauth2' <- withCallbackAndState name oauth2 csrf diff --git a/src/Yesod/Auth/OAuth2/ErrorResponse.hs b/src/Yesod/Auth/OAuth2/ErrorResponse.hs new file mode 100644 index 0000000..de1770a --- /dev/null +++ b/src/Yesod/Auth/OAuth2/ErrorResponse.hs @@ -0,0 +1,60 @@ +{-# LANGUAGE OverloadedStrings #-} +-- | OAuth callback error response +-- +-- +-- +module Yesod.Auth.OAuth2.ErrorResponse + ( ErrorResponse(..) + , ErrorName(..) + , onErrorResponse + ) where + +import Data.Foldable (traverse_) +import Data.Text (Text) +import Data.Traversable (for) +import Yesod.Core (MonadHandler, lookupGetParam) + +data ErrorName + = InvalidRequest + | UnauthorizedClient + | AccessDenied + | UnsupportedResponseType + | InvalidScope + | ServerError + | TemporarilyUnavailable + | Unknown Text + deriving Show + +data ErrorResponse = ErrorResponse + { erName :: ErrorName + , erDescription :: Maybe Text + , erURI :: Maybe Text + } + deriving Show + +-- | Check query parameters for an error, if found run the given action +-- +-- The action is expected to use a short-circuit response function like +-- @'permissionDenied'@, hence this returning @()@. +-- +onErrorResponse :: MonadHandler m => (ErrorResponse -> m a) -> m () +onErrorResponse f = traverse_ f =<< checkErrorResponse + +checkErrorResponse :: MonadHandler m => m (Maybe ErrorResponse) +checkErrorResponse = do + merror <- lookupGetParam "error" + + for merror $ \err -> ErrorResponse + <$> pure (readErrorName err) + <*> lookupGetParam "error_description" + <*> lookupGetParam "error_uri" + +readErrorName :: Text -> ErrorName +readErrorName "invalid_request" = InvalidRequest +readErrorName "unauthorized_client" = UnauthorizedClient +readErrorName "access_denied" = AccessDenied +readErrorName "unsupported_response_type" = UnsupportedResponseType +readErrorName "invalid_scope" = InvalidScope +readErrorName "server_error" = ServerError +readErrorName "temporarily_unavailable" = TemporarilyUnavailable +readErrorName x = Unknown x