fradrive/src/Foundation/Servant.hs
2022-10-12 09:35:16 +02:00

200 lines
8.6 KiB
Haskell

-- SPDX-FileCopyrightText: 2022 Sarah Vaupel <sarah.vaupel@ifi.lmu.de>
--
-- SPDX-License-Identifier: AGPL-3.0-or-later
{-# OPTIONS_GHC -fno-warn-orphans #-}
{-# LANGUAGE UndecidableInstances, InstanceSigs #-}
module Foundation.Servant
( ServantApiDispatchUniWorX(..)
, UniWorXContext
, ServantHandler, ServantDB
) where
import Import.Servant.NoFoundation
import Foundation.DB (runSqlPoolRetry')
import Foundation.Authorization (maybeBearerToken, IsDryRun(..), isDryRun)
import Foundation.Instances ()
import qualified Data.HashMap.Strict.InsOrd as HashMap
import Network.Wai (Middleware, modifyResponse, mapResponseHeaders)
import qualified Network.Wai as W
import Servant.Server.Internal.DelayedIO (DelayedIO, delayedFail, delayedFailFatal)
import qualified Yesod.Servant as Servant
import qualified Data.Text as Text
import Control.Monad.Catch.Pure
import Servant.Server.Internal.Delayed
import Servant.Server.Internal.Router
import Database.Persist.Sql (transactionUndo)
import qualified Data.CaseInsensitive as CI
-- | Handlers behind 'CaptureBearerRestriction'' receive the restriction that
-- the request's bearer token stores for the current route.
--
-- Both the bearer token and the current route are read from the servant
-- context (for 'UniWorXContext' they are supplied by 'servantContext').
-- When the combinator is required ('FoldRequired') the request fails with
-- @400@ if no token is present, or if the token holds no restriction for the
-- current route; otherwise the handler receives 'Nothing' in those cases.
-- A route missing from the context is reported as an internal error (@500@).
instance ( HasServer sub context
         , ToJSON restr, FromJSON restr
         , SBoolI (FoldRequired mods)
         , HasContextEntry context (Maybe (BearerToken UniWorX))
         , HasContextEntry context (Maybe (Route UniWorX))
         )
      => HasServer (CaptureBearerRestriction' mods restr :> sub) context
  where
    type ServerT (CaptureBearerRestriction' mods restr :> sub) m
      = RequiredArgument mods restr -> ServerT sub m

    hoistServerWithContext _ pc nt s = hoistServerWithContext (Proxy @sub) pc nt . s

    route _ context subserver
      = route (Proxy @sub) context (subserver `addAuthCheck` bearerCheck)
      where
        bearerCheck :: DelayedIO (RequiredArgument mods restr)
        bearerCheck = do
          let bearer :: Maybe (BearerToken UniWorX)
              bearer = getContextEntry context
              cRoute :: Maybe (Route UniWorX)
              cRoute = getContextEntry context
              noRouteStored, noTokenProvided, noRestrictionProvided :: ServerError
              noTokenProvided = err400 { errBody = "The behaviour of this route depends on restrictions stored in the bearer token used for authorization. Therefore providing a bearer token is required." }
              noRestrictionProvided = err400 { errBody = "The behaviour of this route depends on restrictions stored in the bearer token used for authorization. Therefore the provided bearer token must contain a restriction entry for this route." }
              -- The route is taken from the servant context (see the
              -- 'getContextEntry' above), not from the WAI vault.
              noRouteStored = err500 { errBody = "The current route was not available in the servant context." }
          -- Any failure below is fatal: servant must not try other routes.
          exceptT delayedFailFatal return $ do
            cRoute' <- maybeExceptT' noRouteStored cRoute
            let mbRet :: Maybe (Maybe restr)
                -- Outer 'Maybe': was a token given? Inner: does it carry a
                -- restriction for this route?
                mbRet = bearer <&> preview (_bearerRestrictionIx cRoute')
            case sbool @(FoldRequired mods) of
              SFalse -> return $ join mbRet
              STrue -> maybe (throwE noTokenProvided) (maybe (throwE noRestrictionProvided) return) mbRet
-- | Handlers behind 'CaptureBearerToken'' receive the request's bearer token
-- itself, as read from the servant context.
--
-- When the combinator is required ('FoldRequired') a missing token fails the
-- request fatally with @400@; otherwise the handler receives the 'Maybe'
-- token unchanged.
instance ( HasServer sub context
         , SBoolI (FoldRequired mods)
         , HasContextEntry context (Maybe (BearerToken UniWorX))
         )
      => HasServer (CaptureBearerToken' mods :> sub) context
  where
    type ServerT (CaptureBearerToken' mods :> sub) m
      = RequiredArgument mods (BearerToken UniWorX) -> ServerT sub m

    hoistServerWithContext _ pc nt s = hoistServerWithContext (Proxy @sub) pc nt . s

    route _ context subserver
      = route (Proxy @sub) context (subserver `addAuthCheck` bearerCheck)
      where
        bearerCheck :: DelayedIO (RequiredArgument mods (BearerToken UniWorX))
        bearerCheck = exceptT delayedFailFatal return $
          case sbool @(FoldRequired mods) of
            SFalse -> return bearer
            STrue -> maybe (throwE noTokenProvided) return bearer

        bearer :: Maybe (BearerToken UniWorX)
        bearer = getContextEntry context

        noTokenProvided :: ServerError
        noTokenProvided = err400 { errBody = "The behaviour of this route depends on restrictions stored in the bearer token used for authorization. Therefore providing a bearer token is required." }
-- | Capture a path segment as a 'CryptoID' and decrypt it, using the
-- foundation's 'appCryptoIDKey', before it reaches the handler.
--
-- Non-lenient captures reject the request (non-fatally) with @400@ when the
-- segment cannot be parsed or decrypted. Lenient captures ('FoldLenient')
-- pass an @Either String plaintext@ to the handler instead, with the parse
-- or decryption error rendered as a 'String'.
instance ( HasServer sub context
         , HasCryptoID ciphertext plaintext (ReaderT CryptoIDKey Catch)
         , SBoolI (FoldLenient mods)
         , FromHttpApiData ciphertext
         , HasContextEntry context UniWorX
         )
      => HasServer (CaptureCryptoID' mods ciphertext sym plaintext :> sub) context
  where
    type ServerT (CaptureCryptoID' mods ciphertext sym plaintext :> sub) m
      = If (FoldLenient mods) (Either String plaintext) plaintext -> ServerT sub m

    hoistServerWithContext _ pc nt s = hoistServerWithContext (Proxy @sub) pc nt . s

    route _ context subserver
      = CaptureRouter
      . route (Proxy @sub) context
      $ addCapture subserver decodeSegment
      where
        -- The signature fixes the result type so that matching on the
        -- singleton below can refine the 'If' type family.
        decodeSegment :: Text -> DelayedIO (If (FoldLenient mods) (Either String plaintext) plaintext)
        decodeSegment txt = case sbool :: SBool (FoldLenient mods) of
            STrue  -> return . left unpack $ join decoded
            SFalse -> case decoded of
              Left parseErr     -> delayedFail err400{ errBody = fromStrict $ encodeUtf8 parseErr }
              Right (Left _)    -> delayedFail err400{ errBody = "Could not decrypt CryptoID" }
              Right (Right pID) -> return pID
          where
            -- Outer 'Either': URL-piece parse result; inner: decryption result.
            decoded = decrypt' <$> parseUrlPiece txt

        decrypt' :: CryptoID ciphertext plaintext -> Either Text plaintext
        decrypt' inp = left tshow . runCatch . runReaderT (decrypt inp) . appCryptoIDKey $ getContextEntry context
-- | Servant context handed to the UniWorX API combinators, in order: the
-- current route (if any), the request's bearer token (if any), the dry-run
-- flag of the request, and the foundation value.
type UniWorXContext =
  '[ Maybe (Route UniWorX)
   , Maybe (BearerToken UniWorX)
   , IsDryRun
   , UniWorX
   ]

-- | Monad in which UniWorX servant handlers run.
type ServantHandler = ServantHandlerFor UniWorX

-- | Database actions that a 'ServantHandler' can execute via 'runDB'.
type ServantDB = ServantDBFor UniWorX

-- Servant logging for UniWorX is derived via 'ServantLogYesod'.
deriving via (ServantLogYesod UniWorX) instance ServantLog UniWorX
-- | Per-request environment of a 'ServantHandler': the foundation, the raw
-- WAI request, and the request's dry-run flag.
instance HasServantHandlerContext UniWorX where
  data ServantHandlerContextFor UniWorX = ServantHandlerContextForUniWorX
    { usctxSite     :: UniWorX
    , usctxRequest  :: W.Request
    , usctxIsDryRun :: IsDryRun
    }

  getSCtxSite ctx = usctxSite ctx
  getSCtxRequest ctx = usctxRequest ctx
-- | APIs that can be served as part of UniWorX.
--
-- An instance only provides the server implementation; the generic
-- 'ServantApiDispatch' instance for 'UniWorXContext' and 'ServantHandler'
-- is defined in terms of this class.
class ( HasServer (ServantApiUnproxy' proxy) UniWorXContext
      , Servant.HasRoute (ServantApiUnproxy' proxy)
      ) => ServantApiDispatchUniWorX proxy where
  -- | The server for the API, running in 'ServantHandler'.
  servantServer'
    :: ServantApi proxy
    -> ServerT (ServantApiUnproxy' proxy) ServantHandler
-- | Dispatch for every API with a 'ServantApiDispatchUniWorX' instance.
instance ServantApiDispatchUniWorX proxy => ServantApiDispatch UniWorXContext ServantHandler UniWorX proxy where
  -- Build the servant context for the current request: route, bearer token,
  -- dry-run flag and the foundation value (in the order of 'UniWorXContext').
  servantContext _ app _ = do
    dryRun  <- isDryRun
    mBearer <- maybeBearerToken
    mRoute  <- getCurrentRoute
    pure $ mRoute :. mBearer :. MkIsDryRun dryRun :. app :. EmptyContext

  -- Run a handler by supplying its per-request environment; the dry-run
  -- flag is taken from the servant context.
  servantHoist _ site req ctx = \handler -> unServantHandlerFor handler ServantHandlerContextForUniWorX
    { usctxSite     = site
    , usctxRequest  = req
    , usctxIsDryRun = getContextEntry ctx
    }

  -- Middlewares, outermost first: the dry-run header (dry runs only), the
  -- default response headers, and the trailing-slash fix.
  servantMiddleware _ _ ctx = foldr (.) id middlewares
    where
      middlewares = dryRunMiddlewares ++
        [ modifyResponse (mapResponseHeaders setDefaultHeaders)
        , fixTrailingSlash
        ]
      dryRunMiddlewares
        | unIsDryRun (getContextEntry ctx) = [modifyResponse $ mapResponseHeaders setDryRunHeader]
        | otherwise                        = []

  servantYesodMiddleware _ _ = pure id

  servantServer p _ = servantServer' p
-- | Response-header transformers used by the servant middleware.
--
-- 'setDefaultHeaders' adds a fixed set of security-related headers;
-- 'setDryRunHeader' adds the 'HeaderDryRun' header with value
-- @toPathPiece True@. In both cases a header is only added when no header of
-- the same (case-insensitive) name is already present, so values set by the
-- handler take precedence. Existing headers are never merged or dropped.
setDefaultHeaders, setDryRunHeader :: ResponseHeaders -> ResponseHeaders
setDefaultHeaders existing = addMissingHeaders existing
  [ ("X-Frame-Options", "sameorigin")
  , ("X-Content-Type-Options", "nosniff")
  , ("Vary", "Accept")
  , ("X-XSS-Protection", "1; mode=block")
  ]
setDryRunHeader existing = addMissingHeaders existing
  [ (CI.mk . encodeUtf8 $ toPathPiece HeaderDryRun, encodeUtf8 $ toPathPiece True) ]

-- | @addMissingHeaders existing defaults@ appends every header of @defaults@
-- whose name does not occur in @existing@.
--
-- @existing@ is kept verbatim. A response may legitimately repeat a header
-- (e.g. several @Set-Cookie@ lines), so round-tripping it through a map would
-- silently drop all but one value. The map is only used to look up which
-- names are already present.
addMissingHeaders :: ResponseHeaders -> ResponseHeaders -> ResponseHeaders
addMissingHeaders existing defaults
  = existing ++ filter (\(name, _) -> not $ HashMap.member name present) defaults
  where
    present = HashMap.fromList existing
-- | Duplicate a trailing slash in the request path before routing.
--
-- The 'runRouter' implementation in @servant-server@ special-cases trailing
-- slashes and discards them. Because every slash matters here, a request
-- whose path ends in an empty segment (i.e. a trailing slash) gets one more
-- empty segment appended.
fixTrailingSlash :: Middleware
fixTrailingSlash app = app . doubleTrailingSlash
  where
    doubleTrailingSlash req = case fromNullable (W.pathInfo req) of
      Just segments
        | Text.null (last segments) -> req { W.pathInfo = W.pathInfo req ++ [Text.empty] }
      _ -> req
-- | Database access for UniWorX servant handlers.
--
-- Actions run on the foundation's 'appConnPool' via 'runSqlPoolRetry''.
-- For dry-run requests 'transactionUndo' is executed after the action, and
-- the action's result is still returned.
instance ServantPersist UniWorX where
  runDB :: HasCallStack => ServantDBFor UniWorX a -> ServantHandlerFor UniWorX a
  runDB = runDB' callStack

  runDB' :: CallStack -> ServantDBFor UniWorX a -> ServantHandlerFor UniWorX a
  runDB' lbl action = do
    $logDebugS "ServantPersist" "runDB"
    MkIsDryRun dryRun <- getsServantContext usctxIsDryRun
    site <- getSite
    let guarded = if dryRun then action <* transactionUndo else action
    runSqlPoolRetry' guarded (appConnPool site) lbl