-- SPDX-FileCopyrightText: 2023-2024 Sarah Vaupel , David Mosbach
--
-- SPDX-License-Identifier: AGPL-3.0-or-later

{-# OPTIONS_GHC -fno-warn-orphans #-}

-- | OAuth2 authentication against Azure AD v2 (plus a local mock server for
-- development), and user queries that use the session's OAuth2 tokens.
module Auth.OAuth2
  ( apAzure
  , azurePrimaryKey, azureUserPrincipalName, azureUserDisplayName, azureUserGivenName, azureUserSurname, azureUserMail, azureUserTelephone, azureUserMobile, azureUserPreferredLanguage
  , azureUser, azureUser'
  , AzureUserException(..), _AzureUserError, _AzureUserNoResult, _AzureUserAmbiguous
  , apAzureMock
  , azureMockServer
  , queryOAuth2User
  , refreshOAuth2Token
  ) where

import qualified Data.CaseInsensitive as CI
import Data.Maybe (fromJust)
import Data.Text
import Import.NoFoundation hiding (unpack)
import Network.HTTP.Simple (httpJSONEither, getResponseBody, JSONException)
import System.Environment (lookupEnv)
import Yesod.Auth.OAuth2
import Yesod.Auth.OAuth2.Prelude hiding (encodeUtf8)


-- | Plugin name of the OAuth2 yesod plugin for Azure ADv2
apAzure :: Text
apAzure = "AzureADv2"


-- | Failure modes of a user lookup via 'azureUser'.
data AzureUserException = AzureUserError
                        | AzureUserNoResult
                        | AzureUserAmbiguous
  deriving (Show, Eq, Generic)

instance Exception AzureUserException

makePrisms ''AzureUserException


-- | Attribute names used for user records.
azurePrimaryKey, azureUserPrincipalName, azureUserDisplayName, azureUserGivenName, azureUserSurname, azureUserMail, azureUserTelephone, azureUserMobile, azureUserPreferredLanguage :: Text
azurePrimaryKey = "id"
azureUserPrincipalName = "userPrincipalName"
azureUserDisplayName = "displayName"
azureUserGivenName = "givenName"
azureUserSurname = "surname"
azureUserMail = "mail"
azureUserTelephone = "businessPhones"
azureUserMobile = "mobilePhone"
azureUserPreferredLanguage = "preferredLanguage"


-- | User lookup in Microsoft Graph with given credentials.
--
-- Throws 'AzureUserNoResult' when no user matches and 'AzureUserAmbiguous'
-- when more than one does. The actual Graph query is still a TODO: the
-- result list is currently always empty, so this always throws
-- 'AzureUserNoResult'.
azureUser :: ( MonadMask m
             , MonadUnliftIO m
             -- , MonadThrow m
             )
          => AzureConf
          -> Creds site
          -> m [(Text, [ByteString])] -- (Either AzureUserException [(Text, [ByteString])])
azureUser _conf _creds = do
  lookupResult <- liftIO . runExceptT $ do
    results <- return [] -- TODO
    case results of
      []         -> throwE AzureUserNoResult
      [res]      -> return res
      _multiple  -> throwE AzureUserAmbiguous
  -- Rethrow in @m@ itself. (Previously @fmap throwLeft@ ran 'throwLeft' in
  -- the list monad, which turned every failure into @[]@.)
  throwLeft lookupResult

-- | User lookup in Microsoft Graph with given user.
--
-- Returns 'Nothing' when the lookup finds no such user; other
-- 'AzureUserException's are rethrown.
azureUser' :: ( MonadMask m
              , MonadUnliftIO m
              -- , MonadThrow m
              )
           => AzureConf
           -> User
           -> m (Maybe [(Text, [ByteString])]) -- (Either AzureUserException [(Text, [ByteString])])
azureUser' conf User{userIdent}
  = runMaybeT . catchIfMaybeT (is _AzureUserNoResult) $ azureUser conf (Creds apAzure (CI.original userIdent) [])


----------------------------------------
---- OAuth2 development auth plugin ----
----------------------------------------

-- | Plugin name of the development OAuth2 mock plugin.
apAzureMock :: Text
apAzureMock = "uniworx_dev"

-- | The @id@ field of the mock server's profile response.
newtype UserID = UserID Text
instance FromJSON UserID where
  parseJSON = withObject "UserID" $ \o ->
    UserID <$> o .: "id"

-- | OAuth2 auth plugin talking to a mock authorization server on
-- @http://localhost:<port>@ (endpoints @/auth@, @/token@, @/users/me@).
azureMockServer :: YesodAuth m => String -> AuthPlugin m
azureMockServer port =
  let oa = OAuth2
        { oauth2ClientId = "42"
        , oauth2ClientSecret = Just "shhh"
        , oauth2AuthorizeEndpoint = (fromString $ mockServerURL <> "/auth")
            `withQuery` [scopeParam " " ["ID", "Profile"]]
        , oauth2TokenEndpoint = fromString $ mockServerURL <> "/token"
        , oauth2RedirectUri = Nothing
        }
      mockServerURL = "http://localhost:" <> fromString port
      profileSrc = fromString $ mockServerURL <> "/users/me"
  in authOAuth2 apAzureMock oa $ \manager token -> do
      (UserID userID, userResponse) <- authGetProfile apAzureMock manager token profileSrc
      return Creds
        { credsPlugin = apAzureMock
        , credsIdent = userID
        , credsExtra = setExtra token userResponse
        }


----------------------
---- User Queries ----
----------------------

-- | Errors raised by 'queryOAuth2User' and 'refreshOAuth2Token'.
data UserDataException
  = UserDataJSONException JSONException -- ^ HTTP or JSON-decoding failure
  | UserDataInternalException Text      -- ^ missing session data or configuration
  deriving Show

instance Exception UserDataException

-- | Query the user endpoint for the given user.
--
-- First refreshes the session's OAuth2 access token and stores the new
-- tokens in the session. It then issues an authenticated GET and decodes
-- the JSON response as @j@. Errors are returned as 'Left' rather than
-- thrown.
queryOAuth2User :: forall j m.
                   ( FromJSON j
                   , MonadHandler m
                   , MonadThrow m
                   )
                => Text -- ^ User identifier (arbitrary needle)
                -> m (Either UserDataException j)
queryOAuth2User userID = runExceptT $ do
  (queryUrl, tokenUrl) <- either throwE return =<< liftIO mkBaseUrls
  -- NOTE(review): userID is appended without percent-encoding; identifiers
  -- containing '/', '?', '#' or spaces would alter the request URL. Confirm
  -- that callers only pass safe identifiers.
  req <- parseRequest $ "GET " ++ queryUrl ++ unpack userID
  mTokens <- lookupSessionJson SessionOAuth2Token
  tokens@(_, oldRefreshToken) <- maybe
    (throwE $ UserDataInternalException "Tried to load session Oauth2 tokens, but there are none")
    return
    mTokens
# ifdef DEVELOPMENT
  let secure = False
# else
  let secure = True
# endif
  newTokens <- refreshOAuth2Token @m tokens tokenUrl secure
  -- A token response need not contain a new refresh token (RFC 6749 §6).
  -- In that case keep the old one instead of erasing it from the session.
  let newRefreshToken = case refreshToken newTokens of
        Nothing -> oldRefreshToken
        rotated -> rotated
  setSessionJson SessionOAuth2Token (Just $ accessToken newTokens, newRefreshToken)
  eResult <- lift $ getResponseBody <$> httpJSONEither @m @j
    (req { secure = secure
         , requestHeaders = [("Authorization", encodeUtf8 . ("Bearer " <>) . atoken $ accessToken newTokens)]
         })
  either (throwE . UserDataJSONException) return eResult

-- | Base URLs @(user query prefix, token endpoint)@.
--
-- Reads @AZURE_TENANT_ID@ in production and @OAUTH2_SERVER_PORT@ in
-- development. Returns 'Left' if the required variable is unset.
mkBaseUrls :: IO (Either UserDataException (String, String))
mkBaseUrls = do
# ifndef DEVELOPMENT
  mTenantID <- lookupEnv "AZURE_TENANT_ID"
  return $ case mTenantID of
    Nothing -> Left $ UserDataInternalException "Environment variable AZURE_TENANT_ID is not set"
    Just tenantID -> Right
      ( "https://graph.microsoft.com/v1.0/users/"
      -- The v2.0 token endpoint is .../oauth2/v2.0/token; this URL is POSTed
      -- to directly by 'refreshOAuth2Token', as in the development branch.
      , "https://login.microsoftonline.com/" ++ tenantID ++ "/oauth2/v2.0/token"
      )
# else
  mPort <- lookupEnv "OAUTH2_SERVER_PORT"
  return $ case mPort of
    Nothing -> Left $ UserDataInternalException "Environment variable OAUTH2_SERVER_PORT is not set"
    Just port ->
      let base = "http://localhost:" ++ port
      in Right ( base ++ "/users/query?id="
               , base ++ "/token"
               )
# endif

-- | Exchange a refresh token for a new 'OAuth2Token' at the given token URL.
--
-- When @secure@ is 'True', the client credentials are read from
-- @CLIENT_ID@ and @CLIENT_SECRET@ and the request uses scope
-- @openid profile@. Otherwise the request uses the mock scope
-- @ID Profile@ and sends no credentials. Fails with
-- 'UserDataInternalException' if the refresh token or a required
-- environment variable is missing, and with 'UserDataJSONException' if
-- the HTTP request or response decoding fails.
refreshOAuth2Token :: forall m.
                      ( MonadHandler m
                      , MonadThrow m
                      )
                   => (Maybe AccessToken, Maybe RefreshToken)
                   -> String
                   -> Bool
                   -> ExceptT UserDataException m OAuth2Token
refreshOAuth2Token (_, Nothing) _ _ =
  throwE $ UserDataInternalException "Could not refresh access token. Refresh token is missing."
refreshOAuth2Token (_, Just rToken) url secure = do
  req <- parseRequest $ "POST " ++ url
  let body =
        [ ("grant_type", "refresh_token")
        , ("refresh_token", encodeUtf8 $ rtoken rToken)
        ]
  body' <- if secure
    then do
      clientID     <- requireEnv "CLIENT_ID"
      clientSecret <- requireEnv "CLIENT_SECRET"
      return $ body ++
        [ ("client_id", fromString clientID)
        , ("client_secret", fromString clientSecret)
        , ("scope", "openid profile")
        ]
    else return $ ("scope", "ID Profile") : body
  -- The request body holds the refresh token and the client secret; do not log it.
  let tokenReq = urlEncodedBody body' req { secure = secure }
  eResult <- lift $ getResponseBody <$> httpJSONEither @m @OAuth2Token tokenReq
  either (throwE . UserDataJSONException) return eResult
  where
    -- Look up a required environment variable, failing inside the ExceptT
    -- instead of crashing via a partial pattern.
    requireEnv :: String -> ExceptT UserDataException m String
    requireEnv var =
      liftIO (lookupEnv var) >>= maybe
        (throwE . UserDataInternalException $ "Environment variable " <> fromString var <> " is not set")
        return


-- | Debugging aid: only lazy-bytestring bodies are rendered. Other body
-- kinds (streams, builders, ...) show as a placeholder instead of crashing.
instance Show RequestBody where
  show (RequestBodyLBS x) = show x
  show _ = "<RequestBody>"