Stop using fromJust and change to throw exceptions instead.

This commit is contained in:
Hiromi Ishii 2012-12-03 22:29:48 +09:00
parent a4ae1803f7
commit 3ed5b54d65

View File

@ -1,39 +1,46 @@
{-# LANGUAGE QuasiQuotes, OverloadedStrings #-}
{-# OPTIONS_GHC -fwarn-unused-imports #-}
{-# LANGUAGE DeriveDataTypeable, OverloadedStrings, QuasiQuotes #-}
module Yesod.Auth.OAuth
( authOAuth
, oauthUrl
, authTwitter
, twitterUrl
, authTumblr
, tumblrUrl
, authTumblr
, tumblrUrl
, module Web.Authenticate.OAuth
) where
import Control.Applicative ((<$>), (<*>))
import Control.Arrow ((***))
import Control.Exception.Lifted
import Control.Monad.IO.Class
import Data.ByteString (ByteString)
import Data.Maybe
import Data.Text (Text)
import qualified Data.Text as T
import Data.Text.Encoding (decodeUtf8With, encodeUtf8)
import Data.Text.Encoding.Error (lenientDecode)
import Data.Typeable
import Web.Authenticate.OAuth
import Yesod.Auth
import Yesod.Form
import Yesod.Handler
import Yesod.Widget
import Yesod.Auth
import Yesod.Form
import Yesod.Handler
import Yesod.Widget
import Web.Authenticate.OAuth
import Data.Maybe
import Control.Arrow ((***))
import Control.Monad.IO.Class
import Data.Text (Text)
import qualified Data.Text as T
import Data.Text.Encoding (encodeUtf8, decodeUtf8With)
import Data.Text.Encoding.Error (lenientDecode)
import Data.ByteString (ByteString)
import Control.Applicative ((<$>), (<*>))
data YesodOAuthException = CredentialError String Credential
| SessionError String
deriving (Show, Typeable)
instance Exception YesodOAuthException
oauthUrl :: Text -> AuthRoute
oauthUrl name = PluginR name ["forward"]
authOAuth :: YesodAuth m
=> OAuth -- ^ 'OAuth' data-type for signing.
-> (Credential -> IO (Creds m)) -- ^ How to extract ident.
-> (Credential -> IO (Creds m)) -- ^ How to extract ident.
-> AuthPlugin m
authOAuth oauth mkCreds = AuthPlugin name dispatch login
where
getOAuthSession = maybe (throwIO $ SessionError "") return =<< lookupSession oauthSessionName
name = T.pack $ oauthServerName oauth
url = PluginR name []
lookupTokenSecret = bsToText . fromMaybe "" . lookup "oauth_token_secret" . unCredential
@ -51,7 +58,7 @@ authOAuth oauth mkCreds = AuthPlugin name dispatch login
if oauthVersion oauth == OAuth10
then do
oaTok <- runInputGet $ ireq textField "oauth_token"
tokSec <- fromJust <$> lookupSession oauthSessionName
tokSec <- getOAuthSession
deleteSession oauthSessionName
return $ Credential [ ("oauth_token", encodeUtf8 oaTok)
, ("oauth_token_secret", encodeUtf8 tokSec)
@ -60,7 +67,7 @@ authOAuth oauth mkCreds = AuthPlugin name dispatch login
(verifier, oaTok) <-
runInputGet $ (,) <$> ireq textField "oauth_verifier"
<*> ireq textField "oauth_token"
tokSec <- fromJust <$> lookupSession oauthSessionName
tokSec <- getOAuthSession
deleteSession oauthSessionName
return $ Credential [ ("oauth_verifier", encodeUtf8 verifier)
, ("oauth_token", encodeUtf8 oaTok)
@ -74,8 +81,13 @@ authOAuth oauth mkCreds = AuthPlugin name dispatch login
login tm = do
render <- lift getUrlRender
let oaUrl = render $ tm $ oauthUrl name
addWidget
[whamlet| <a href=#{oaUrl}>Login via #{name} |]
[whamlet| <a href=#{oaUrl}>Login via #{name} |]
mkExtractCreds name idName (Credential dic) = do
let mcrId = decodeUtf8With lenientDecode <$> lookup (encodeUtf8 $ T.pack idName) dic
case mcrId of
Just crId -> return $ Creds name crId $ map (bsToText *** bsToText) dic
Nothing -> throwIO $ CredentialError ("key not found: " ++ idName) (Credential dic)
authTwitter :: YesodAuth m
=> ByteString -- ^ Consumer Key
@ -83,19 +95,15 @@ authTwitter :: YesodAuth m
-> AuthPlugin m
authTwitter key secret = authOAuth
(newOAuth { oauthServerName = "twitter"
, oauthRequestUri = "https://api.twitter.com/oauth/request_token"
, oauthAccessTokenUri = "https://api.twitter.com/oauth/access_token"
, oauthAuthorizeUri = "https://api.twitter.com/oauth/authorize"
, oauthRequestUri = "http://twitter.com/oauth/request_token"
, oauthAccessTokenUri = "http://api.twitter.com/oauth/access_token"
, oauthAuthorizeUri = "http://api.twitter.com/oauth/authorize"
, oauthSignatureMethod = HMACSHA1
, oauthConsumerKey = key
, oauthConsumerSecret = secret
, oauthVersion = OAuth10a
})
extractCreds
where
extractCreds (Credential dic) = do
let crId = decodeUtf8With lenientDecode $ fromJust $ lookup "screen_name" dic
return $ Creds "twitter" crId $ map (bsToText *** bsToText ) dic
(mkExtractCreds "twitter" "screen_name")
twitterUrl :: AuthRoute
twitterUrl = oauthUrl "twitter"
@ -114,11 +122,7 @@ authTumblr key secret = authOAuth
, oauthConsumerSecret = secret
, oauthVersion = OAuth10a
})
extractCreds
where
extractCreds (Credential dic) = do
let crId = decodeUtf8With lenientDecode $ fromJust $ lookup "name" dic
return $ Creds "tumblr" crId $ map (bsToText *** bsToText ) dic
(mkExtractCreds "tumblr" "name")
tumblrUrl :: AuthRoute
tumblrUrl = oauthUrl "tumblr"