Stop using fromJust and change to throw exceptions instead.

This commit is contained in:
Hiromi Ishii 2012-12-03 22:29:48 +09:00
parent a4ae1803f7
commit 3ed5b54d65

View File

@ -1,5 +1,4 @@
{-# LANGUAGE QuasiQuotes, OverloadedStrings #-} {-# LANGUAGE DeriveDataTypeable, OverloadedStrings, QuasiQuotes #-}
{-# OPTIONS_GHC -fwarn-unused-imports #-}
module Yesod.Auth.OAuth module Yesod.Auth.OAuth
( authOAuth ( authOAuth
, oauthUrl , oauthUrl
@ -9,21 +8,28 @@ module Yesod.Auth.OAuth
, tumblrUrl , tumblrUrl
, module Web.Authenticate.OAuth , module Web.Authenticate.OAuth
) where ) where
import Control.Applicative ((<$>), (<*>))
import Control.Arrow ((***))
import Control.Exception.Lifted
import Control.Monad.IO.Class
import Data.ByteString (ByteString)
import Data.Maybe
import Data.Text (Text)
import qualified Data.Text as T
import Data.Text.Encoding (decodeUtf8With, encodeUtf8)
import Data.Text.Encoding.Error (lenientDecode)
import Data.Typeable
import Web.Authenticate.OAuth
import Yesod.Auth import Yesod.Auth
import Yesod.Form import Yesod.Form
import Yesod.Handler import Yesod.Handler
import Yesod.Widget import Yesod.Widget
import Web.Authenticate.OAuth
import Data.Maybe data YesodOAuthException = CredentialError String Credential
import Control.Arrow ((***)) | SessionError String
import Control.Monad.IO.Class deriving (Show, Typeable)
import Data.Text (Text)
import qualified Data.Text as T instance Exception YesodOAuthException
import Data.Text.Encoding (encodeUtf8, decodeUtf8With)
import Data.Text.Encoding.Error (lenientDecode)
import Data.ByteString (ByteString)
import Control.Applicative ((<$>), (<*>))
oauthUrl :: Text -> AuthRoute oauthUrl :: Text -> AuthRoute
oauthUrl name = PluginR name ["forward"] oauthUrl name = PluginR name ["forward"]
@ -34,6 +40,7 @@ authOAuth :: YesodAuth m
-> AuthPlugin m -> AuthPlugin m
authOAuth oauth mkCreds = AuthPlugin name dispatch login authOAuth oauth mkCreds = AuthPlugin name dispatch login
where where
getOAuthSession = maybe (throwIO $ SessionError "") return =<< lookupSession oauthSessionName
name = T.pack $ oauthServerName oauth name = T.pack $ oauthServerName oauth
url = PluginR name [] url = PluginR name []
lookupTokenSecret = bsToText . fromMaybe "" . lookup "oauth_token_secret" . unCredential lookupTokenSecret = bsToText . fromMaybe "" . lookup "oauth_token_secret" . unCredential
@ -51,7 +58,7 @@ authOAuth oauth mkCreds = AuthPlugin name dispatch login
if oauthVersion oauth == OAuth10 if oauthVersion oauth == OAuth10
then do then do
oaTok <- runInputGet $ ireq textField "oauth_token" oaTok <- runInputGet $ ireq textField "oauth_token"
tokSec <- fromJust <$> lookupSession oauthSessionName tokSec <- getOAuthSession
deleteSession oauthSessionName deleteSession oauthSessionName
return $ Credential [ ("oauth_token", encodeUtf8 oaTok) return $ Credential [ ("oauth_token", encodeUtf8 oaTok)
, ("oauth_token_secret", encodeUtf8 tokSec) , ("oauth_token_secret", encodeUtf8 tokSec)
@ -60,7 +67,7 @@ authOAuth oauth mkCreds = AuthPlugin name dispatch login
(verifier, oaTok) <- (verifier, oaTok) <-
runInputGet $ (,) <$> ireq textField "oauth_verifier" runInputGet $ (,) <$> ireq textField "oauth_verifier"
<*> ireq textField "oauth_token" <*> ireq textField "oauth_token"
tokSec <- fromJust <$> lookupSession oauthSessionName tokSec <- getOAuthSession
deleteSession oauthSessionName deleteSession oauthSessionName
return $ Credential [ ("oauth_verifier", encodeUtf8 verifier) return $ Credential [ ("oauth_verifier", encodeUtf8 verifier)
, ("oauth_token", encodeUtf8 oaTok) , ("oauth_token", encodeUtf8 oaTok)
@ -74,28 +81,29 @@ authOAuth oauth mkCreds = AuthPlugin name dispatch login
login tm = do login tm = do
render <- lift getUrlRender render <- lift getUrlRender
let oaUrl = render $ tm $ oauthUrl name let oaUrl = render $ tm $ oauthUrl name
addWidget
[whamlet| <a href=#{oaUrl}>Login via #{name} |] [whamlet| <a href=#{oaUrl}>Login via #{name} |]
mkExtractCreds name idName (Credential dic) = do
let mcrId = decodeUtf8With lenientDecode <$> lookup (encodeUtf8 $ T.pack idName) dic
case mcrId of
Just crId -> return $ Creds name crId $ map (bsToText *** bsToText) dic
Nothing -> throwIO $ CredentialError ("key not found: " ++ idName) (Credential dic)
authTwitter :: YesodAuth m authTwitter :: YesodAuth m
=> ByteString -- ^ Consumer Key => ByteString -- ^ Consumer Key
-> ByteString -- ^ Consumer Secret -> ByteString -- ^ Consumer Secret
-> AuthPlugin m -> AuthPlugin m
authTwitter key secret = authOAuth authTwitter key secret = authOAuth
(newOAuth { oauthServerName = "twitter" (newOAuth { oauthServerName = "twitter"
, oauthRequestUri = "https://api.twitter.com/oauth/request_token" , oauthRequestUri = "http://twitter.com/oauth/request_token"
, oauthAccessTokenUri = "https://api.twitter.com/oauth/access_token" , oauthAccessTokenUri = "http://api.twitter.com/oauth/access_token"
, oauthAuthorizeUri = "https://api.twitter.com/oauth/authorize" , oauthAuthorizeUri = "http://api.twitter.com/oauth/authorize"
, oauthSignatureMethod = HMACSHA1 , oauthSignatureMethod = HMACSHA1
, oauthConsumerKey = key , oauthConsumerKey = key
, oauthConsumerSecret = secret , oauthConsumerSecret = secret
, oauthVersion = OAuth10a , oauthVersion = OAuth10a
}) })
extractCreds (mkExtractCreds "twitter" "screen_name")
where
extractCreds (Credential dic) = do
let crId = decodeUtf8With lenientDecode $ fromJust $ lookup "screen_name" dic
return $ Creds "twitter" crId $ map (bsToText *** bsToText ) dic
twitterUrl :: AuthRoute twitterUrl :: AuthRoute
twitterUrl = oauthUrl "twitter" twitterUrl = oauthUrl "twitter"
@ -114,11 +122,7 @@ authTumblr key secret = authOAuth
, oauthConsumerSecret = secret , oauthConsumerSecret = secret
, oauthVersion = OAuth10a , oauthVersion = OAuth10a
}) })
extractCreds (mkExtractCreds "tumblr" "name")
where
extractCreds (Credential dic) = do
let crId = decodeUtf8With lenientDecode $ fromJust $ lookup "name" dic
return $ Creds "tumblr" crId $ map (bsToText *** bsToText ) dic
tumblrUrl :: AuthRoute tumblrUrl :: AuthRoute
tumblrUrl = oauthUrl "tumblr" tumblrUrl = oauthUrl "tumblr"