Stop using fromJust and change to throw exceptions instead.

This commit is contained in:
Hiromi Ishii 2012-12-03 22:29:48 +09:00
parent a4ae1803f7
commit 3ed5b54d65

View File

@ -1,39 +1,46 @@
{-# LANGUAGE QuasiQuotes, OverloadedStrings #-} {-# LANGUAGE DeriveDataTypeable, OverloadedStrings, QuasiQuotes #-}
{-# OPTIONS_GHC -fwarn-unused-imports #-}
module Yesod.Auth.OAuth module Yesod.Auth.OAuth
( authOAuth ( authOAuth
, oauthUrl , oauthUrl
, authTwitter , authTwitter
, twitterUrl , twitterUrl
, authTumblr , authTumblr
, tumblrUrl , tumblrUrl
, module Web.Authenticate.OAuth , module Web.Authenticate.OAuth
) where ) where
import Control.Applicative ((<$>), (<*>))
import Control.Arrow ((***))
import Control.Exception.Lifted
import Control.Monad.IO.Class
import Data.ByteString (ByteString)
import Data.Maybe
import Data.Text (Text)
import qualified Data.Text as T
import Data.Text.Encoding (decodeUtf8With, encodeUtf8)
import Data.Text.Encoding.Error (lenientDecode)
import Data.Typeable
import Web.Authenticate.OAuth
import Yesod.Auth
import Yesod.Form
import Yesod.Handler
import Yesod.Widget
import Yesod.Auth data YesodOAuthException = CredentialError String Credential
import Yesod.Form | SessionError String
import Yesod.Handler deriving (Show, Typeable)
import Yesod.Widget
import Web.Authenticate.OAuth instance Exception YesodOAuthException
import Data.Maybe
import Control.Arrow ((***))
import Control.Monad.IO.Class
import Data.Text (Text)
import qualified Data.Text as T
import Data.Text.Encoding (encodeUtf8, decodeUtf8With)
import Data.Text.Encoding.Error (lenientDecode)
import Data.ByteString (ByteString)
import Control.Applicative ((<$>), (<*>))
oauthUrl :: Text -> AuthRoute oauthUrl :: Text -> AuthRoute
oauthUrl name = PluginR name ["forward"] oauthUrl name = PluginR name ["forward"]
authOAuth :: YesodAuth m authOAuth :: YesodAuth m
=> OAuth -- ^ 'OAuth' data-type for signing. => OAuth -- ^ 'OAuth' data-type for signing.
-> (Credential -> IO (Creds m)) -- ^ How to extract ident. -> (Credential -> IO (Creds m)) -- ^ How to extract ident.
-> AuthPlugin m -> AuthPlugin m
authOAuth oauth mkCreds = AuthPlugin name dispatch login authOAuth oauth mkCreds = AuthPlugin name dispatch login
where where
getOAuthSession = maybe (throwIO $ SessionError "") return =<< lookupSession oauthSessionName
name = T.pack $ oauthServerName oauth name = T.pack $ oauthServerName oauth
url = PluginR name [] url = PluginR name []
lookupTokenSecret = bsToText . fromMaybe "" . lookup "oauth_token_secret" . unCredential lookupTokenSecret = bsToText . fromMaybe "" . lookup "oauth_token_secret" . unCredential
@ -51,7 +58,7 @@ authOAuth oauth mkCreds = AuthPlugin name dispatch login
if oauthVersion oauth == OAuth10 if oauthVersion oauth == OAuth10
then do then do
oaTok <- runInputGet $ ireq textField "oauth_token" oaTok <- runInputGet $ ireq textField "oauth_token"
tokSec <- fromJust <$> lookupSession oauthSessionName tokSec <- getOAuthSession
deleteSession oauthSessionName deleteSession oauthSessionName
return $ Credential [ ("oauth_token", encodeUtf8 oaTok) return $ Credential [ ("oauth_token", encodeUtf8 oaTok)
, ("oauth_token_secret", encodeUtf8 tokSec) , ("oauth_token_secret", encodeUtf8 tokSec)
@ -60,7 +67,7 @@ authOAuth oauth mkCreds = AuthPlugin name dispatch login
(verifier, oaTok) <- (verifier, oaTok) <-
runInputGet $ (,) <$> ireq textField "oauth_verifier" runInputGet $ (,) <$> ireq textField "oauth_verifier"
<*> ireq textField "oauth_token" <*> ireq textField "oauth_token"
tokSec <- fromJust <$> lookupSession oauthSessionName tokSec <- getOAuthSession
deleteSession oauthSessionName deleteSession oauthSessionName
return $ Credential [ ("oauth_verifier", encodeUtf8 verifier) return $ Credential [ ("oauth_verifier", encodeUtf8 verifier)
, ("oauth_token", encodeUtf8 oaTok) , ("oauth_token", encodeUtf8 oaTok)
@ -74,8 +81,13 @@ authOAuth oauth mkCreds = AuthPlugin name dispatch login
login tm = do login tm = do
render <- lift getUrlRender render <- lift getUrlRender
let oaUrl = render $ tm $ oauthUrl name let oaUrl = render $ tm $ oauthUrl name
addWidget [whamlet| <a href=#{oaUrl}>Login via #{name} |]
[whamlet| <a href=#{oaUrl}>Login via #{name} |]
mkExtractCreds name idName (Credential dic) = do
let mcrId = decodeUtf8With lenientDecode <$> lookup (encodeUtf8 $ T.pack idName) dic
case mcrId of
Just crId -> return $ Creds name crId $ map (bsToText *** bsToText) dic
Nothing -> throwIO $ CredentialError ("key not found: " ++ idName) (Credential dic)
authTwitter :: YesodAuth m authTwitter :: YesodAuth m
=> ByteString -- ^ Consumer Key => ByteString -- ^ Consumer Key
@ -83,19 +95,15 @@ authTwitter :: YesodAuth m
-> AuthPlugin m -> AuthPlugin m
authTwitter key secret = authOAuth authTwitter key secret = authOAuth
(newOAuth { oauthServerName = "twitter" (newOAuth { oauthServerName = "twitter"
, oauthRequestUri = "https://api.twitter.com/oauth/request_token" , oauthRequestUri = "http://twitter.com/oauth/request_token"
, oauthAccessTokenUri = "https://api.twitter.com/oauth/access_token" , oauthAccessTokenUri = "http://api.twitter.com/oauth/access_token"
, oauthAuthorizeUri = "https://api.twitter.com/oauth/authorize" , oauthAuthorizeUri = "http://api.twitter.com/oauth/authorize"
, oauthSignatureMethod = HMACSHA1 , oauthSignatureMethod = HMACSHA1
, oauthConsumerKey = key , oauthConsumerKey = key
, oauthConsumerSecret = secret , oauthConsumerSecret = secret
, oauthVersion = OAuth10a , oauthVersion = OAuth10a
}) })
extractCreds (mkExtractCreds "twitter" "screen_name")
where
extractCreds (Credential dic) = do
let crId = decodeUtf8With lenientDecode $ fromJust $ lookup "screen_name" dic
return $ Creds "twitter" crId $ map (bsToText *** bsToText ) dic
twitterUrl :: AuthRoute twitterUrl :: AuthRoute
twitterUrl = oauthUrl "twitter" twitterUrl = oauthUrl "twitter"
@ -114,11 +122,7 @@ authTumblr key secret = authOAuth
, oauthConsumerSecret = secret , oauthConsumerSecret = secret
, oauthVersion = OAuth10a , oauthVersion = OAuth10a
}) })
extractCreds (mkExtractCreds "tumblr" "name")
where
extractCreds (Credential dic) = do
let crId = decodeUtf8With lenientDecode $ fromJust $ lookup "name" dic
return $ Creds "tumblr" crId $ map (bsToText *** bsToText ) dic
tumblrUrl :: AuthRoute tumblrUrl :: AuthRoute
tumblrUrl = oauthUrl "tumblr" tumblrUrl = oauthUrl "tumblr"