mirror of
https://github.com/freckle/yesod-auth-oauth2.git
synced 2026-01-28 20:00:24 +01:00
Latest LTS-11.5. Allow hoauth2-1.7 (it needs to be an extra-dep, though). Support *and require* yesod-1.6. This required: fewer lifts; HandlerFor, WidgetFor, etc.; MonadThrow was lost, but MonadIO can be used instead.
132 lines
4.6 KiB
Haskell
132 lines
4.6 KiB
Haskell
{-# LANGUAGE FlexibleContexts #-}
|
|
{-# LANGUAGE OverloadedStrings #-}
|
|
{-# LANGUAGE RankNTypes #-}
|
|
{-# LANGUAGE RecordWildCards #-}
|
|
{-# LANGUAGE TemplateHaskell #-}
|
|
{-# LANGUAGE TypeFamilies #-}
|
|
module Yesod.Auth.OAuth2.Dispatch
|
|
( FetchCreds
|
|
, dispatchAuthRequest
|
|
) where
|
|
|
|
import Control.Exception.Safe (throwString, tryIO)
|
|
import Control.Monad (unless, (<=<))
|
|
import Data.Monoid ((<>))
|
|
import Data.Text (Text)
|
|
import qualified Data.Text as T
|
|
import Data.Text.Encoding (encodeUtf8)
|
|
import Network.HTTP.Conduit (Manager)
|
|
import Network.OAuth.OAuth2
|
|
import System.Random (newStdGen, randomRs)
|
|
import URI.ByteString.Extension
|
|
import Yesod.Auth
|
|
import Yesod.Auth.OAuth2.ErrorResponse (onErrorResponse)
|
|
import Yesod.Core
|
|
|
|
-- | A strategy for obtaining user credentials: given an HTTP @'Manager'@ and
-- an @'OAuth2Token'@, produce the corresponding @'Creds'@ in @IO@.
type FetchCreds m
    = Manager
    -> OAuth2Token
    -> IO (Creds m)
|
|
|
|
-- | Route an incoming auth-plugin request to the matching OAuth2 handler
--
-- * @GET forward@ is handled by 'dispatchForward'
-- * @GET callback@ is handled by 'dispatchCallback'
-- * any other method\/path combination responds with 'notFound'
--
dispatchAuthRequest
    :: Text -- ^ Name
    -> OAuth2 -- ^ Service details
    -> FetchCreds m -- ^ How to get credentials
    -> Text -- ^ Method
    -> [Text] -- ^ Path pieces
    -> AuthHandler m TypedContent
dispatchAuthRequest name oauth2 getCreds method pieces =
    case (method, pieces) of
        ("GET", ["forward"]) -> dispatchForward name oauth2
        ("GET", ["callback"]) -> dispatchCallback name oauth2 getCreds
        _ -> notFound
|
|
|
|
-- | Handle @GET \/forward@
--
-- 1. Store a freshly generated CSRF token in the session, under the key
--    derived from the plugin name
-- 2. Build the service settings carrying our callback and that token
-- 3. Redirect the user to the Provider's authorization URL
--
dispatchForward :: Text -> OAuth2 -> AuthHandler m TypedContent
dispatchForward name oauth2 = do
    state <- setSessionCSRF sessionKey
    service <- withCallbackAndState name oauth2 state
    let destination = toText (authorizationUrl service)
    redirect destination
  where
    sessionKey = tokenSessionKey name
|
|
|
|
-- | Handle @GET \/callback@
--
-- 1. Check that the @state@ parameter matches the CSRF token in our session
-- 2. Let 'onErrorResponse' deal with an error reported by the Provider
-- 3. Exchange the required @code@ parameter for an access token
-- 4. Turn that token into a @'Creds'@ value using the given 'FetchCreds'
-- 5. Hand the credentials to 'setCredsRedirect'
--
-- A @Left@ from either the token exchange or the credentials fetch is logged
-- and answered with an opaque permission-denied (see 'errInvalidOAuth').
--
dispatchCallback :: Text -> OAuth2 -> FetchCreds m -> AuthHandler m TypedContent
dispatchCallback name oauth2 getCreds = do
    state <- verifySessionCSRF (tokenSessionKey name)
    onErrorResponse errInvalidOAuth
    authCode <- requireGetParam "code"
    httpManager <- authHttpManager
    service <- withCallbackAndState name oauth2 state
    accessToken <- denyLeft
        (fetchAccessToken httpManager service (ExchangeToken authCode))
    creds <- denyLeft (tryIO (getCreds httpManager accessToken))
    setCredsRedirect creds
  where
    -- Run the IO action, then reject a Left result or return the Right value
    denyLeft :: (MonadHandler n, MonadLogger n, Show e) => IO (Either e a) -> n a
    denyLeft = unwrap <=< liftIO

    -- Log-and-deny on Left; pass the value through on Right
    unwrap :: (MonadHandler n, MonadLogger n, Show e) => Either e a -> n a
    unwrap (Left err) = errInvalidOAuth err
    unwrap (Right value) = pure value
|
|
|
|
-- | Log the given error (via its 'Show' instance) at error level, then
-- short-circuit with a generic permission-denied response
--
-- The response text is fixed and does not include the error, so details only
-- reach the log.
--
errInvalidOAuth :: (MonadHandler m, MonadLogger m, Show e) => e -> m a
errInvalidOAuth err =
    $(logError) message >> permissionDenied "Invalid OAuth2 authentication attempt"
  where
    message = T.pack ("OAuth2 error: " <> show err)
|
|
|
|
-- | Prepare the service settings for one leg of the handshake
--
-- The plugin's @callback@ route is rendered to text and parsed with
-- 'fromText'. If that fails, an exception is thrown via 'throwString'. The
-- message itself suggests a non-absolute Approot as the likely cause.
--
-- The result is the given @'OAuth2'@ with 'oauthCallback' set to the parsed
-- URI and 'oauthOAuthorizeEndpoint' passed through 'withQuery' together with
-- a @state@ pair holding the UTF-8 encoded CSRF token.
--
withCallbackAndState :: Text -> OAuth2 -> Text -> AuthHandler m OAuth2
withCallbackAndState name oauth2 csrf = do
    renderRoute <- getParentUrlRender
    let callbackText = renderRoute (PluginR name ["callback"])

    callback <- case fromText callbackText of
        Just uri -> pure uri
        Nothing -> liftIO $ throwString $ concat
            [ "Invalid callback URI: "
            , T.unpack callbackText
            , ". Not using an absolute Approot?"
            ]

    pure oauth2
        { oauthCallback = Just callback
        , oauthOAuthorizeEndpoint = withQuery
            (oauthOAuthorizeEndpoint oauth2)
            [("state", encodeUtf8 csrf)]
        }
|
|
|
|
-- | Build a renderer for subsite routes
--
-- Composes the result of 'getUrlRender' with the result of
-- 'getRouteToParent', so a subsite route is first lifted by the latter and
-- then rendered to 'Text' by the former.
--
getParentUrlRender :: MonadHandler m => m (Route (SubHandlerSite m) -> Text)
getParentUrlRender = do
    renderParent <- getUrlRender
    toParent <- getRouteToParent
    pure (renderParent . toParent)
|
|
|
|
-- | Generate a random token, store it in the session, and return it
--
-- The token is 30 characters long, each drawn from @\'a\'@ to @\'z\'@, and is
-- saved under the given session key.
--
-- NOTE(review): the token comes from 'newStdGen', a general-purpose
-- generator. Confirm whether a cryptographically secure source is warranted
-- for a CSRF\/state token.
--
setSessionCSRF :: MonadHandler m => Text -> m Text
setSessionCSRF sessionKey = do
    gen <- liftIO newStdGen
    let csrfToken = T.pack (take 30 (randomRs ('a', 'z') gen))
    setSession sessionKey csrfToken
    pure csrfToken
|
|
|
|
-- | Check the callback's @state@ parameter against the token in our session
--
-- The session entry is deleted right after it is read, whether or not the
-- comparison succeeds. On a mismatch (including no token in the session) the
-- request is answered with permission-denied; otherwise the token from the
-- request is returned.
--
verifySessionCSRF :: MonadHandler m => Text -> m Text
verifySessionCSRF sessionKey = do
    received <- requireGetParam "state"
    expected <- lookupSession sessionKey <* deleteSession sessionKey
    unless (Just received == expected)
        (permissionDenied "Invalid OAuth2 state token")
    pure received
|
|
|
|
-- | Look up a GET parameter, responding with 'invalidArgs' when it is absent
requireGetParam :: MonadHandler m => Text -> m Text
requireGetParam key = do
    found <- lookupGetParam key
    case found of
        Just value -> pure value
        Nothing -> invalidArgs ["The '" <> key <> "' parameter is required"]
|
|
|
|
-- | Session key under which the CSRF token for the named plugin is stored:
-- the plugin name prefixed with @_yesod_oauth2_@
tokenSessionKey :: Text -> Text
tokenSessionKey = T.append "_yesod_oauth2_"
|