-- SPDX-FileCopyrightText: 2022 Sarah Vaupel <sarah.vaupel@ifi.lmu.de>
--
-- SPDX-License-Identifier: AGPL-3.0-or-later
|
{-# OPTIONS_GHC -fno-warn-orphans #-}
|
|
{-# LANGUAGE UndecidableInstances, InstanceSigs #-}
|
|
|
|
module Foundation.Servant
|
|
( ServantApiDispatchUniWorX(..)
|
|
, UniWorXContext
|
|
, ServantHandler, ServantDB
|
|
) where
|
|
|
|
import Import.Servant.NoFoundation
|
|
import Foundation.DB (runSqlPoolRetry')
|
|
import Foundation.Authorization (maybeBearerToken, IsDryRun(..), isDryRun)
|
|
import Foundation.Instances ()
|
|
|
|
import qualified Data.HashMap.Strict.InsOrd as HashMap
|
|
|
|
import Network.Wai (Middleware, modifyResponse, mapResponseHeaders)
|
|
import qualified Network.Wai as W
|
|
|
|
import Servant.Server.Internal.DelayedIO (DelayedIO, delayedFail, delayedFailFatal)
|
|
|
|
import qualified Yesod.Servant as Servant
|
|
|
|
import qualified Data.Text as Text
|
|
|
|
import Control.Monad.Catch.Pure
|
|
|
|
import Servant.Server.Internal.Delayed
|
|
import Servant.Server.Internal.Router
|
|
|
|
import Database.Persist.Sql (transactionUndo)
|
|
|
|
import qualified Data.CaseInsensitive as CI
|
|
|
|
|
|
-- | Servant combinator instance for 'CaptureBearerRestriction''.
--
-- The handler receives the restriction entry that the bearer token (read from
-- the servant 'Context') stores for the current route (also read from the
-- 'Context'). When @mods@ make the argument required, the handler gets a plain
-- @restr@. A missing token or a missing restriction entry then aborts the
-- request with HTTP 400. Otherwise the handler gets a @Maybe restr@.
-- A missing route entry always aborts with HTTP 500.
instance ( HasServer sub context
         , ToJSON restr, FromJSON restr
         , SBoolI (FoldRequired mods)
         , HasContextEntry context (Maybe (BearerToken UniWorX))
         , HasContextEntry context (Maybe (Route UniWorX))
         )
      => HasServer (CaptureBearerRestriction' mods restr :> sub) context
  where
    type ServerT (CaptureBearerRestriction' mods restr :> sub) m
      = RequiredArgument mods restr -> ServerT sub m

    hoistServerWithContext _ pc nt server = hoistServerWithContext (Proxy @sub) pc nt . server

    route _ context subserver
      = route (Proxy @sub) context $ addAuthCheck subserver extractRestriction
      where
        tokenInContext :: Maybe (BearerToken UniWorX)
        tokenInContext = getContextEntry context

        routeInContext :: Maybe (Route UniWorX)
        routeInContext = getContextEntry context

        -- The restriction entry stored in a token for the given route, if any.
        restrictionFor :: Route UniWorX -> BearerToken UniWorX -> Maybe restr
        restrictionFor r tok = preview (_bearerRestrictionIx r) tok

        -- NOTE(review): in this module the route entry is supplied by
        -- 'servantContext' (via getCurrentRoute), while this message still
        -- mentions servantYesodMiddleware; confirm the wording.
        noRouteStored, noTokenProvided, noRestrictionProvided :: ServerError
        noTokenProvided = err400 { errBody = "The behaviour of this route depends on restrictions stored in the bearer token used for authorization. Therefor providing a bearer token is required." }
        noRestrictionProvided = err400 { errBody = "The behaviour of this route depends on restrictions stored in the bearer token used for authorization. Therefor the provided bearer token must contain a restriction entry for this route." }
        noRouteStored = err500 { errBody = "servantYesodMiddleware did not store current route in WAI vault." }

        -- Every failure goes through 'delayedFailFatal', so servant does not
        -- fall back to trying other routes.
        extractRestriction :: DelayedIO (RequiredArgument mods restr)
        extractRestriction = case routeInContext of
          Nothing -> delayedFailFatal noRouteStored
          Just currentRoute -> case sbool @(FoldRequired mods) of
            SFalse -> return $ tokenInContext >>= restrictionFor currentRoute
            STrue -> case tokenInContext of
              Nothing -> delayedFailFatal noTokenProvided
              Just tok -> maybe (delayedFailFatal noRestrictionProvided) return $ restrictionFor currentRoute tok
|
|
|
|
|
|
-- | Servant combinator instance for 'CaptureBearerToken''.
--
-- The handler receives the bearer token found in the servant 'Context'.
-- When @mods@ make the argument required, the handler gets the token itself.
-- A request without a token is then aborted with HTTP 400 via
-- 'delayedFailFatal'. Otherwise the handler gets a @Maybe@ token.
instance ( HasServer sub context
         , SBoolI (FoldRequired mods)
         , HasContextEntry context (Maybe (BearerToken UniWorX))
         )
      => HasServer (CaptureBearerToken' mods :> sub) context
  where
    type ServerT (CaptureBearerToken' mods :> sub) m
      = RequiredArgument mods (BearerToken UniWorX) -> ServerT sub m

    hoistServerWithContext _ pc nt server = hoistServerWithContext (Proxy @sub) pc nt . server

    route _ context subserver
      = route (Proxy @sub) context $ addAuthCheck subserver extractToken
      where
        tokenInContext :: Maybe (BearerToken UniWorX)
        tokenInContext = getContextEntry context

        noTokenProvided :: ServerError
        noTokenProvided = err400 { errBody = "The behaviour of this route depends on restrictions stored in the bearer token used for authorization. Therefor providing a bearer token is required." }

        extractToken :: DelayedIO (RequiredArgument mods (BearerToken UniWorX))
        extractToken = case (sbool @(FoldRequired mods), tokenInContext) of
          (SFalse, optionalToken) -> return optionalToken
          (STrue, Just tok) -> return tok
          (STrue, Nothing) -> delayedFailFatal noTokenProvided
|
|
|
|
|
|
-- | Servant combinator instance for 'CaptureCryptoID''.
--
-- A captured path segment is parsed as a 'CryptoID' and decrypted with the
-- site's CryptoID key (taken from the 'UniWorX' value in the servant
-- 'Context').
--
-- * Lenient @mods@: the handler receives @Either String plaintext@, carrying
--   either the parse error or the shown decryption exception.
-- * Otherwise: a parse or decryption failure is reported with HTTP 400 via
--   'delayedFail'.
instance ( HasServer sub context
         , HasCryptoID ciphertext plaintext (ReaderT CryptoIDKey Catch)
         , SBoolI (FoldLenient mods)
         , FromHttpApiData ciphertext
         , HasContextEntry context UniWorX
         )
      => HasServer (CaptureCryptoID' mods ciphertext sym plaintext :> sub) context
  where
    type ServerT (CaptureCryptoID' mods ciphertext sym plaintext :> sub) m
      = If (FoldLenient mods) (Either String plaintext) plaintext -> ServerT sub m

    hoistServerWithContext _ pc nt server = hoistServerWithContext (Proxy @sub) pc nt . server

    route _ context subserver
      = CaptureRouter . route (Proxy @sub) context $ addCapture subserver $ \txt ->
          case sbool @(FoldLenient mods) of
            -- Lenient: hand both failure kinds to the handler as a String.
            STrue -> return . left unpack $ decrypt' =<< parseUrlPiece txt
            -- Strict: reject the capture with HTTP 400 on any failure.
            SFalse -> case parseUrlPiece txt of
              Left e -> delayedFail err400{ errBody = fromStrict $ encodeUtf8 e }
              Right cID -> either (const $ delayedFail err400{ errBody = "Could not decrypt CryptoID" }) return $ decrypt' cID
      where
        -- Key used for decryption, read from the 'UniWorX' context entry.
        cryptoKey = appCryptoIDKey $ getContextEntry context

        -- Decrypt purely; any thrown exception is rendered with 'tshow'.
        decrypt' :: CryptoID ciphertext plaintext -> Either Text plaintext
        decrypt' cID = left tshow . runCatch $ runReaderT (decrypt cID) cryptoKey
|
|
|
|
|
|
-- | Entries of the servant 'Context' for UniWorX APIs, in the order that
-- 'servantContext' builds them: the current route (if any), the bearer token
-- (if any), the dry-run flag, and the 'UniWorX' site value.
type UniWorXContext = '[ Maybe (Route UniWorX), Maybe (BearerToken UniWorX), IsDryRun, UniWorX ]

-- | Handler monad for UniWorX servant endpoints.
type ServantHandler = ServantHandlerFor UniWorX

-- | Database monad for UniWorX servant endpoints.
type ServantDB = ServantDBFor UniWorX

-- Servant logging for UniWorX reuses the 'ServantLogYesod' implementation.
deriving via (ServantLogYesod UniWorX) instance ServantLog UniWorX
|
|
|
|
-- | Per-request state made available to UniWorX servant handlers.
instance HasServantHandlerContext UniWorX where
  data ServantHandlerContextFor UniWorX = ServantHandlerContextForUniWorX
    { usctxSite :: UniWorX
      -- ^ The UniWorX site value.
    , usctxRequest :: W.Request
      -- ^ The WAI request being served.
    , usctxIsDryRun :: IsDryRun
      -- ^ Dry-run flag taken from the servant 'Context'.
    }

  getSCtxSite ServantHandlerContextForUniWorX{ usctxSite = uwSite } = uwSite
  getSCtxRequest ServantHandlerContextForUniWorX{ usctxRequest = uwRequest } = uwRequest
|
|
|
|
-- | Servant APIs that UniWorX can dispatch.
--
-- The superclasses require the unproxied API to be servable by servant with a
-- 'UniWorXContext' and to have a 'Servant.HasRoute' instance.
class ( HasServer (ServantApiUnproxy' proxy) UniWorXContext
      , Servant.HasRoute (ServantApiUnproxy' proxy)
      )
   => ServantApiDispatchUniWorX proxy where
  -- | The server for the API identified by @proxy@, written against
  -- 'ServantHandler'.
  servantServer' :: ServantApi proxy -> ServerT (ServantApiUnproxy' proxy) ServantHandler
|
|
|
|
-- | Generic servant dispatch for every API with a 'ServantApiDispatchUniWorX'
-- instance.
instance ServantApiDispatchUniWorX proxy => ServantApiDispatch UniWorXContext ServantHandler UniWorX proxy where
  -- Assemble the servant 'Context' from the current Yesod handler state.
  -- The entries must appear in the order fixed by 'UniWorXContext'.
  servantContext _ app _ = do
    dryRunFlag <- MkIsDryRun <$> isDryRun
    bearer <- maybeBearerToken
    cRoute <- getCurrentRoute
    pure $ cRoute :. bearer :. dryRunFlag :. app :. EmptyContext

  -- Run a 'ServantHandler' by supplying its per-request context record.
  -- The dry-run flag is copied out of the servant 'Context'.
  servantHoist _ hoistSite hoistRequest ctx = flip unServantHandlerFor handlerCtx
    where
      handlerCtx = ServantHandlerContextForUniWorX
        { usctxSite = hoistSite
        , usctxRequest = hoistRequest
        , usctxIsDryRun = getContextEntry ctx
        }

  -- Compose the middlewares so the first list element is outermost.
  -- The dry-run header middleware is only included for dry-run requests.
  servantMiddleware _ _ ctx = foldr (.) id middlewares
    where
      dryRunMiddlewares =
        [ modifyResponse $ mapResponseHeaders setDryRunHeader | unIsDryRun $ getContextEntry ctx ]
      middlewares =
        dryRunMiddlewares ++ [ modifyResponse $ mapResponseHeaders setDefaultHeaders, fixTrailingSlash ]

  servantYesodMiddleware _ _ = pure id

  servantServer p _ = servantServer' p
|
|
|
|
-- | Response-header post-processing used by the servant middleware.
--
-- * 'setDefaultHeaders' adds UniWorX's default security/caching headers.
-- * 'setDryRunHeader' adds the dry-run marker header.
--
-- A header is only added when no header of that (case-insensitive) name is
-- already present, so values set by the handler win. Existing headers are
-- kept verbatim, including repeated names such as multiple @Set-Cookie@
-- headers. (Previously the headers were round-tripped through a map, which
-- kept only one value per name and reordered them.)
setDefaultHeaders, setDryRunHeader :: ResponseHeaders -> ResponseHeaders
setDefaultHeaders = addHeadersIfAbsent
  [ ("X-Frame-Options", "sameorigin")
  , ("X-Content-Type-Options", "nosniff")
  , ("Vary", "Accept")
  , ("X-XSS-Protection", "1; mode=block")
  ]
setDryRunHeader = addHeadersIfAbsent
  [ (CI.mk . encodeUtf8 $ toPathPiece HeaderDryRun, encodeUtf8 $ toPathPiece True) ]

-- | @addHeadersIfAbsent extra existing@ returns @existing@ unchanged, followed
-- by those entries of @extra@ whose name does not occur in @existing@.
addHeadersIfAbsent :: ResponseHeaders -> ResponseHeaders -> ResponseHeaders
addHeadersIfAbsent extra existing = existing ++ filter isAbsent extra
  where
    -- Used only for name lookups; its values are never read, so duplicate
    -- names in 'existing' are harmless here.
    present = HashMap.fromList existing
    isAbsent (name, _) = not $ HashMap.member name present
|
|
|
|
-- | Keep trailing slashes significant under servant routing.
--
-- @servant-server@ contains a special case in its implementation of
-- @runRouter@ that discards trailing slashes. Because all slashes matter,
-- this middleware duplicates a trailing slash. When the request path's last
-- segment is empty, one more empty segment is appended before the wrapped
-- application sees the request. All other requests pass through unchanged.
fixTrailingSlash :: Middleware
fixTrailingSlash app = app . duplicateTrailingSlash
  where
    duplicateTrailingSlash req = case fromNullable (W.pathInfo req) of
      Just segments
        | Text.null (last segments) -> req { W.pathInfo = W.pathInfo req ++ [Text.empty] }
      _ -> req
|
|
|
|
|
|
-- | Database access for UniWorX servant handlers.
instance ServantPersist UniWorX where
  -- Forwards the caller's 'CallStack' (available through 'HasCallStack') to
  -- 'runDB''.
  runDB :: HasCallStack => ServantDBFor UniWorX a -> ServantHandlerFor UniWorX a
  runDB action = runDB' callStack action
|
|
|
|
-- | Run a database action on the site's connection pool ('appConnPool') via
-- 'runSqlPoolRetry''. @lbl@ is passed on as that function's 'CallStack'
-- argument.
--
-- For a dry-run request ('usctxIsDryRun'), the action is followed by
-- 'transactionUndo'. This presumably discards the action's writes. It
-- assumes 'runSqlPoolRetry'' runs the whole action in one transaction;
-- TODO confirm.
runDB' :: CallStack -> ServantDBFor UniWorX a -> ServantHandlerFor UniWorX a
runDB' lbl action = do
  $logDebugS "ServantPersist" "runDB"
  MkIsDryRun dryRun <- getsServantContext usctxIsDryRun
  pool <- appConnPool <$> getSite
  let effectiveAction = if dryRun then action <* transactionUndo else action
  runSqlPoolRetry' effectiveAction pool lbl
|