This repository has been archived on 2024-10-24. You can view files and clone it, but cannot push or open issues or pull requests.
fradrive-old/src/Foundation/Servant.hs

204 lines
8.8 KiB
Haskell

-- SPDX-FileCopyrightText: 2022 Sarah Vaupel <sarah.vaupel@ifi.lmu.de>
--
-- SPDX-License-Identifier: AGPL-3.0-or-later
{-# OPTIONS_GHC -fno-warn-orphans #-}
{-# LANGUAGE UndecidableInstances, InstanceSigs #-}
module Foundation.Servant
( ServantApiDispatchUniWorX(..)
, UniWorXContext
, ServantHandler, ServantDB
) where
import Import.Servant.NoFoundation
import Foundation.DB (runSqlPoolRetry')
import Foundation.Authorization (maybeBearerToken, IsDryRun(..), isDryRun)
import Foundation.Instances ()
import qualified Data.HashMap.Strict.InsOrd as HashMap
import Network.Wai (Middleware, modifyResponse, mapResponseHeaders)
import qualified Network.Wai as W
import Servant.Server.Internal.DelayedIO (DelayedIO, delayedFail, delayedFailFatal)
import qualified Yesod.Servant as Servant
import qualified Data.Text as Text
import Control.Monad.Catch.Pure
import Servant.Server.Internal.Delayed
import Servant.Server.Internal.Router
import Database.Persist.Sql (transactionUndo)
import qualified Data.CaseInsensitive as CI
import GHC.TypeLits (symbolVal)
import Data.Typeable
-- | Server interpretation of 'CaptureBearerRestriction''.
--
-- The handler is passed the restriction found in the request's bearer token
-- under the current route (looked up with '_bearerRestrictionIx').  Depending
-- on @FoldRequired mods@ the argument is either mandatory (a missing token or
-- a missing restriction is answered with a fatal HTTP 400) or optional (the
-- handler gets a 'Maybe').  A missing current route in the context is
-- answered with a fatal HTTP 500.
instance ( HasServer sub context
         , ToJSON restr, FromJSON restr
         , SBoolI (FoldRequired mods)
         , HasContextEntry context (Maybe (BearerToken UniWorX))
         , HasContextEntry context (Maybe (Route UniWorX))
         )
      => HasServer (CaptureBearerRestriction' mods restr :> sub) context
  where
    type ServerT (CaptureBearerRestriction' mods restr :> sub) m
      = RequiredArgument mods restr -> ServerT sub m

    hoistServerWithContext _ pc nt server = hoistServerWithContext (Proxy @sub) pc nt . server

    route _ context subserver
      = route (Proxy @sub) context (subserver `addAuthCheck` restrictionCheck)
      where
        -- Runs as an auth check; every failure is fatal, i.e. reported via
        -- 'delayedFailFatal'.
        restrictionCheck :: DelayedIO (RequiredArgument mods restr)
        restrictionCheck = exceptT delayedFailFatal return $ do
          currentRoute <- maybeExceptT' noRouteStored mbRoute
          let -- Outer 'Maybe': was a token supplied at all?
              -- Inner 'Maybe': does it hold a restriction for this route?
              lookedUp :: Maybe (Maybe restr)
              lookedUp = preview (_bearerRestrictionIx currentRoute) <$> mbToken
          case sbool @(FoldRequired mods) of
            SFalse -> return $ join lookedUp
            STrue  -> case lookedUp of
              Nothing                 -> throwE noTokenProvided
              Just Nothing            -> throwE noRestrictionProvided
              Just (Just restriction) -> return restriction

        mbToken :: Maybe (BearerToken UniWorX)
        mbToken = getContextEntry context

        mbRoute :: Maybe (Route UniWorX)
        mbRoute = getContextEntry context

        noRouteStored, noTokenProvided, noRestrictionProvided :: ServerError
        noTokenProvided = err400 { errBody = "The behaviour of this route depends on restrictions stored in the bearer token used for authorization. Therefor providing a bearer token is required." }
        noRestrictionProvided = err400 { errBody = "The behaviour of this route depends on restrictions stored in the bearer token used for authorization. Therefor the provided bearer token must contain a restriction entry for this route." }
        noRouteStored = err500 { errBody = "servantYesodMiddleware did not store current route in WAI vault." }
-- | Server interpretation of 'CaptureBearerToken''.
--
-- The handler is passed the bearer token stored in the servant context.
-- When @FoldRequired mods@ is true a missing token is answered with a fatal
-- HTTP 400; otherwise the handler receives the 'Maybe' as-is.
instance ( HasServer sub context
         , SBoolI (FoldRequired mods)
         , HasContextEntry context (Maybe (BearerToken UniWorX))
         )
      => HasServer (CaptureBearerToken' mods :> sub) context
  where
    type ServerT (CaptureBearerToken' mods :> sub) m
      = RequiredArgument mods (BearerToken UniWorX) -> ServerT sub m

    hoistServerWithContext _ pc nt server = hoistServerWithContext (Proxy @sub) pc nt . server

    route _ context subserver
      = route (Proxy @sub) context (subserver `addAuthCheck` tokenCheck)
      where
        -- Runs as an auth check; a failure is reported via 'delayedFailFatal'.
        tokenCheck :: DelayedIO (RequiredArgument mods (BearerToken UniWorX))
        tokenCheck = exceptT delayedFailFatal return $
          case sbool @(FoldRequired mods) of
            SFalse -> return mbToken
            STrue  -> case mbToken of
              Nothing    -> throwE noTokenProvided
              Just token -> return token

        mbToken :: Maybe (BearerToken UniWorX)
        mbToken = getContextEntry context

        noTokenProvided :: ServerError
        noTokenProvided = err400 { errBody = "The behaviour of this route depends on restrictions stored in the bearer token used for authorization. Therefor providing a bearer token is required." }
-- | Server interpretation of 'CaptureCryptoID''.
--
-- Captures one path segment, parses it as a @CryptoID ciphertext plaintext@
-- and decrypts it with the 'appCryptoIDKey' of the 'UniWorX' value found in
-- the servant context.
--
-- * Strict mode (@FoldLenient mods@ is false): a parse or decryption failure
--   is reported with 'delayedFail' as HTTP 400; the handler receives the
--   plaintext.
-- * Lenient mode: nothing is rejected here; the handler receives
--   @Either String plaintext@ and sees the failure message itself.
instance ( HasServer sub context
         , HasCryptoID ciphertext plaintext (ReaderT CryptoIDKey Catch)
         , SBoolI (FoldLenient mods)
         , FromHttpApiData ciphertext
         , HasContextEntry context UniWorX
         , KnownSymbol sym
         )
      => HasServer (CaptureCryptoID' mods ciphertext sym plaintext :> sub) context
  where
    type ServerT (CaptureCryptoID' mods ciphertext sym plaintext :> sub) m
      = If (FoldLenient mods) (Either String plaintext) plaintext -> ServerT sub m

    hoistServerWithContext _ pc nt server = hoistServerWithContext (Proxy @sub) pc nt . server

    route _ context subserver
      = CaptureRouter [captureHint]
      . route (Proxy @sub) context
      . addCapture subserver
      $ \segment -> case (sbool @(FoldLenient mods), decryptPiece <$> parseUrlPiece segment) of
          (SFalse, Left parseErr)       -> delayedFail err400{ errBody = fromStrict $ encodeUtf8 parseErr }
          (SFalse, Right (Left _))      -> delayedFail err400{ errBody = "Could not decrypt CryptoID" }
          (SFalse, Right (Right plain)) -> return plain
          -- Lenient: merge the parse and decryption layers into one 'Either'.
          (STrue,  outcome)             -> return . left unpack $ join outcome
      where
        -- Pure decryption; an exception caught by 'runCatch' is rendered with 'tshow'.
        decryptPiece :: CryptoID ciphertext plaintext -> Either Text plaintext
        decryptPiece cryptoID = left tshow . runCatch $ runReaderT (decrypt cryptoID) cryptoKey
          where cryptoKey = appCryptoIDKey (getContextEntry context)

        -- from Servant.Server.Internal and modified for our usage
        captureHint = CaptureHint (Text.pack . symbolVal $ Proxy @sym) (typeRep (Proxy @sym))
-- | Entries of the servant context used for UniWorX APIs: the current route
-- (if any), the bearer token (if any), the dry-run flag and the foundation
-- value, in that order.
type UniWorXContext = '[ Maybe (Route UniWorX), Maybe (BearerToken UniWorX), IsDryRun, UniWorX ]

-- | 'ServantHandlerFor' specialised to 'UniWorX'.
type ServantHandler = ServantHandlerFor UniWorX

-- | 'ServantDBFor' specialised to 'UniWorX'.
type ServantDB = ServantDBFor UniWorX
-- The 'ServantLog' instance for 'UniWorX' is taken over unchanged from the
-- 'ServantLogYesod' wrapper via @DerivingVia@.
deriving via (ServantLogYesod UniWorX) instance ServantLog UniWorX
-- | Context record carried by servant handlers of 'UniWorX'.
instance HasServantHandlerContext UniWorX where
  data ServantHandlerContextFor UniWorX = ServantHandlerContextForUniWorX
    { usctxSite     :: UniWorX    -- ^ the foundation value
    , usctxRequest  :: W.Request  -- ^ the WAI request
    , usctxIsDryRun :: IsDryRun   -- ^ dry-run flag
    }

  getSCtxSite ctx = usctxSite ctx

  getSCtxRequest ctx = usctxRequest ctx
-- | Servant APIs that can be dispatched by 'UniWorX': the API type behind
-- @proxy@ must have a server for 'UniWorXContext' and a 'Servant.HasRoute'
-- instance.
class ( HasServer (ServantApiUnproxy' proxy) UniWorXContext
      , Servant.HasRoute (ServantApiUnproxy' proxy)
      )
   => ServantApiDispatchUniWorX proxy where
  -- The server implementation for the API, running in 'ServantHandler'.
  servantServer' :: ServantApi proxy -> ServerT (ServantApiUnproxy' proxy) ServantHandler
-- | Generic dispatch for every API that implements 'ServantApiDispatchUniWorX'.
instance ServantApiDispatchUniWorX proxy => ServantApiDispatch UniWorXContext ServantHandler UniWorX proxy where
  -- Assemble the servant context; the order of the entries must match
  -- 'UniWorXContext'.
  servantContext _ app _ = do
    dryRunFlag   <- MkIsDryRun <$> isDryRun
    bearerToken  <- maybeBearerToken
    currentRoute <- getCurrentRoute
    return $ currentRoute :. bearerToken :. dryRunFlag :. app :. EmptyContext

  -- Run a handler by supplying it with the per-request handler context.
  servantHoist _ site request ctx = \handler -> unServantHandlerFor handler handlerContext
    where
      handlerContext = ServantHandlerContextForUniWorX
        { usctxSite     = site
        , usctxRequest  = request
        , usctxIsDryRun = getContextEntry ctx
        }

  -- The list is folded with 'Endo', i.e. composed with @(.)@ from left to
  -- right; the dry-run header middleware is only present for dry runs.
  servantMiddleware _ _ ctx = appEndo $ foldMap Endo middlewares
    where
      middlewares =
        guardOn (unIsDryRun $ getContextEntry ctx) (modifyResponse $ mapResponseHeaders setDryRunHeader)
          ++ [ modifyResponse (mapResponseHeaders setDefaultHeaders)
             , fixTrailingSlash
             ]

  -- No additional Yesod-side middleware.
  servantYesodMiddleware _ _ = return id

  servantServer proxy _ = servantServer' proxy
-- | 'setDefaultHeaders' adds a fixed set of default response headers and
-- 'setDryRunHeader' adds the header marking a response as produced by a dry
-- run.  In both cases a header is only added if the response does not
-- already carry a header of that name; headers already present are never
-- modified, reordered or dropped.
setDefaultHeaders, setDryRunHeader :: ResponseHeaders -> ResponseHeaders
setDefaultHeaders = addMissingHeaders
  [ ("X-Frame-Options", "sameorigin")
  , ("X-Content-Type-Options", "nosniff")
  , ("Vary", "Accept")
  , ("X-XSS-Protection", "1; mode=block")
  ]
setDryRunHeader = addMissingHeaders
  [ (CI.mk . encodeUtf8 $ toPathPiece HeaderDryRun, encodeUtf8 $ toPathPiece True) ]

-- | @addMissingHeaders fallbacks existing@ appends those @fallbacks@ whose
-- header name does not occur in @existing@, keeping their relative order.
--
-- The @existing@ list is returned untouched in front of the additions.  It
-- must not be round-tripped through a map: a header name may legitimately
-- occur several times (e.g. @Set-Cookie@) and converting to a map would
-- silently drop all but one of those entries.
addMissingHeaders :: ResponseHeaders -> ResponseHeaders -> ResponseHeaders
addMissingHeaders fallbacks existing = existing ++ HashMap.toList missing
  where
    -- The maps are only used to determine which fallback names are absent.
    missing = HashMap.fromList fallbacks `HashMap.difference` HashMap.fromList existing
-- | `servant-server` contains a special case in its implementation of
-- `runRouter` that discards a trailing slash.
--
-- Because all slashes matter, this middleware duplicates a trailing slash:
-- if the last path segment of the request is empty, one more empty segment
-- is appended before the request is passed on.  Other requests are passed
-- on unchanged.
fixTrailingSlash :: Middleware
fixTrailingSlash app = app . duplicateTrailingSlash
  where
    duplicateTrailingSlash req = case fromNullable segments of
      Just nonEmptySegments
        | Text.null (last nonEmptySegments) -> req { W.pathInfo = segments ++ [Text.empty] }
      _other -> req
      where
        segments = W.pathInfo req
-- Database access from servant handlers of 'UniWorX'.
instance ServantPersist UniWorX where
-- Entry point: captures the caller's call stack and delegates to runDB'.
runDB :: HasCallStack => ServantDBFor UniWorX a -> ServantHandlerFor UniWorX a
runDB = runDB' callStack
-- Runs a database action with 'runSqlPoolRetry'' on the site's connection
-- pool ('appConnPool'), passing along the given call stack.
runDB' :: CallStack -> ServantDBFor UniWorX a -> ServantHandlerFor UniWorX a
runDB' lbl action = do
$logDebugS "ServantPersist" "runDB"
-- The dry-run flag is read from the per-request handler context.
MkIsDryRun dryRun <- getsServantContext usctxIsDryRun
let action'
-- For a dry run the action is still executed, but 'transactionUndo' is
-- called right after it; the action's result is returned either way.
| dryRun = action <* transactionUndo
| otherwise = action
flip (runSqlPoolRetry' action') lbl . appConnPool =<< getSite