-- SPDX-FileCopyrightText: 2022 Sarah Vaupel
--
-- SPDX-License-Identifier: AGPL-3.0-or-later

{-# OPTIONS_GHC -fno-warn-orphans #-}
{-# LANGUAGE UndecidableInstances, InstanceSigs #-}

module Foundation.Servant
  ( ServantApiDispatchUniWorX(..)
  , UniWorXContext
  , ServantHandler, ServantDB
  ) where

import Import.Servant.NoFoundation

import Foundation.DB (runSqlPoolRetry')
import Foundation.Authorization (maybeBearerToken, IsDryRun(..), isDryRun)
import Foundation.Instances ()

import qualified Data.HashMap.Strict.InsOrd as HashMap

import Network.Wai (Middleware, modifyResponse, mapResponseHeaders)
import qualified Network.Wai as W

import Servant.Server.Internal.DelayedIO (DelayedIO, delayedFail, delayedFailFatal)

import qualified Yesod.Servant as Servant

import qualified Data.Text as Text

import Control.Monad.Catch.Pure

import Servant.Server.Internal.Delayed
import Servant.Server.Internal.Router

import Database.Persist.Sql (transactionUndo)

import qualified Data.CaseInsensitive as CI

import GHC.TypeLits (symbolVal)
import Data.Typeable


-- | Passes the restriction stored in the request's bearer token for the
-- current route to the handler.
--
-- Both the bearer token and the current route are read from the servant
-- context. All failures go through 'delayedFailFatal':
--
-- * 500 if the context holds no current route (checked first, regardless of
--   whether the restriction is required);
-- * for required restrictions, 400 if no bearer token was provided or the
--   token has no restriction entry for the current route.
--
-- For optional restrictions a missing token or entry yields 'Nothing'.
instance ( HasServer sub context
         , ToJSON restr, FromJSON restr
         , SBoolI (FoldRequired mods)
         , HasContextEntry context (Maybe (BearerToken UniWorX))
         , HasContextEntry context (Maybe (Route UniWorX))
         ) => HasServer (CaptureBearerRestriction' mods restr :> sub) context where
  type ServerT (CaptureBearerRestriction' mods restr :> sub) m
    = RequiredArgument mods restr -> ServerT sub m

  hoistServerWithContext _ pc nt s = hoistServerWithContext (Proxy @sub) pc nt . s

  route _ context subserver
    = route (Proxy @sub) context (subserver `addAuthCheck` bearerCheck)
    where
      bearerCheck :: DelayedIO (RequiredArgument mods restr)
      bearerCheck = do
        let bearer :: Maybe (BearerToken UniWorX)
            bearer = getContextEntry context
            cRoute :: Maybe (Route UniWorX)
            cRoute = getContextEntry context

            noRouteStored, noTokenProvided, noRestrictionProvided :: ServerError
            noTokenProvided = err400
              { errBody = "The behaviour of this route depends on restrictions stored in the bearer token used for authorization. Therefore providing a bearer token is required." }
            noRestrictionProvided = err400
              { errBody = "The behaviour of this route depends on restrictions stored in the bearer token used for authorization. Therefore the provided bearer token must contain a restriction entry for this route." }
            noRouteStored = err500
              { errBody = "servantYesodMiddleware did not store current route in WAI vault." }

        exceptT delayedFailFatal return $ do
          cRoute' <- maybeExceptT' noRouteStored cRoute
          -- Outer 'Maybe': was a token provided?
          -- Inner 'Maybe': does the token carry a restriction for this route?
          let mbRet :: Maybe (Maybe restr)
              mbRet = bearer <&> preview (_bearerRestrictionIx cRoute')
          case sbool @(FoldRequired mods) of
            SFalse -> return $ join mbRet
            STrue  -> maybe (throwE noTokenProvided) (maybe (throwE noRestrictionProvided) return) mbRet

-- | Passes the request's bearer token (read from the servant context) to the
-- handler.
--
-- If the token is required and absent, the request fails with 400 via
-- 'delayedFailFatal'; if it is optional, the handler receives the 'Maybe'
-- as-is.
instance ( HasServer sub context
         , SBoolI (FoldRequired mods)
         , HasContextEntry context (Maybe (BearerToken UniWorX))
         ) => HasServer (CaptureBearerToken' mods :> sub) context where
  type ServerT (CaptureBearerToken' mods :> sub) m
    = RequiredArgument mods (BearerToken UniWorX) -> ServerT sub m

  hoistServerWithContext _ pc nt s = hoistServerWithContext (Proxy @sub) pc nt . s

  route _ context subserver
    = route (Proxy @sub) context (subserver `addAuthCheck` bearerCheck)
    where
      bearerCheck :: DelayedIO (RequiredArgument mods (BearerToken UniWorX))
      bearerCheck = do
        let bearer :: Maybe (BearerToken UniWorX)
            bearer = getContextEntry context

            noTokenProvided :: ServerError
            noTokenProvided = err400
              { errBody = "The behaviour of this route depends on restrictions stored in the bearer token used for authorization. Therefore providing a bearer token is required." }

        exceptT delayedFailFatal return $
          case sbool @(FoldRequired mods) of
            SFalse -> return bearer
            STrue  -> maybe (throwE noTokenProvided) return bearer

-- | Captures a path segment as an encrypted 'CryptoID' and hands the
-- decrypted @plaintext@ to the handler.
--
-- Adapted from servant-server's @Capture'@ instance in
-- "Servant.Server.Internal" and modified for our usage.
--
-- The decryption key is taken from the 'UniWorX' foundation value in the
-- servant context.
--
-- * Non-lenient captures fail with 400 via 'delayedFail' when the segment
--   cannot be parsed or decrypted.
-- * Lenient captures instead receive the error message as a 'Left'.
instance ( HasServer sub context
         , HasCryptoID ciphertext plaintext (ReaderT CryptoIDKey Catch)
         , SBoolI (FoldLenient mods)
         , FromHttpApiData ciphertext
         , HasContextEntry context UniWorX
         , KnownSymbol sym
         , Typeable plaintext
         ) => HasServer (CaptureCryptoID' mods ciphertext sym plaintext :> sub) context where
  type ServerT (CaptureCryptoID' mods ciphertext sym plaintext :> sub) m
    = If (FoldLenient mods) (Either String plaintext) plaintext -> ServerT sub m

  hoistServerWithContext _ pc nt s = hoistServerWithContext (Proxy @sub) pc nt . s

  route _ context subserver
    = CaptureRouter [hint] . route (Proxy @sub) context . addCapture subserver $ \txt ->
        case ( sbool :: SBool (FoldLenient mods)
             , decrypt' <$> parseUrlPiece txt
             ) of
          (SFalse, Left e)            -> delayedFail err400{ errBody = fromStrict $ encodeUtf8 e }
          (SFalse, Right (Left _))    -> delayedFail err400{ errBody = "Could not decrypt CryptoID" }
          (SFalse, Right (Right pID)) -> return pID
          (STrue,  join -> piece)     -> return $ left unpack piece
    where
      -- Decrypt using the CryptoID key of the foundation value in the context;
      -- any exception raised during decryption is rendered via 'tshow'.
      decrypt' :: CryptoID ciphertext plaintext -> Either Text plaintext
      decrypt' inp = left tshow . runCatch . runReaderT (decrypt inp) . appCryptoIDKey $ getContextEntry context

      -- Name and type of the capture, as servant's own 'Capture'' instance
      -- reports them: the type is that of the value handed to the handler, not
      -- the type-level capture name.
      hint = CaptureHint (Text.pack . symbolVal $ Proxy @sym) (typeRep $ Proxy @plaintext)


-- | Servant context available to UniWorX servant handlers.
--
-- It holds, in this order:
--
-- 1. the current Yesod route, if any;
-- 2. the bearer token as obtained by 'maybeBearerToken';
-- 3. the dry-run flag;
-- 4. the foundation value.
type UniWorXContext = Maybe (Route UniWorX) ': Maybe (BearerToken UniWorX) ': IsDryRun ': UniWorX ': '[]

type ServantHandler = ServantHandlerFor UniWorX
type ServantDB = ServantDBFor UniWorX

deriving via (ServantLogYesod UniWorX) instance ServantLog UniWorX

instance HasServantHandlerContext UniWorX where
  -- | Per-request state available inside 'ServantHandler'.
  data ServantHandlerContextFor UniWorX = ServantHandlerContextForUniWorX
    { usctxSite     :: UniWorX
    , usctxRequest  :: W.Request
    , usctxIsDryRun :: IsDryRun
    }

  getSCtxSite = usctxSite
  getSCtxRequest = usctxRequest

-- | Servant APIs that can be served by UniWorX: the unproxied API must have a
-- server over 'UniWorXContext' and a Yesod route representation.
class ( HasServer (ServantApiUnproxy' proxy) UniWorXContext
      , Servant.HasRoute (ServantApiUnproxy' proxy)
      ) => ServantApiDispatchUniWorX proxy where
  servantServer' :: ServantApi proxy -> ServerT (ServantApiUnproxy' proxy) ServantHandler

instance ServantApiDispatchUniWorX proxy => ServantApiDispatch UniWorXContext ServantHandler UniWorX proxy where
  servantContext _ app _ = do
    isDryRun' <- MkIsDryRun <$> isDryRun
    restr <- maybeBearerToken
    cRoute <- getCurrentRoute
    return $ cRoute :. restr :. isDryRun' :. app :. EmptyContext

  servantHoist _ usctxSite usctxRequest ctx
    = ($ ServantHandlerContextForUniWorX{ usctxIsDryRun = getContextEntry ctx, .. }) . unServantHandlerFor

  -- The middlewares are composed in list order (first element outermost).
  -- The dry-run header is only added when the context flags a dry run.
  servantMiddleware _ _ ctx
    = appEndo . foldMap Endo
    $ guardOn (unIsDryRun $ getContextEntry ctx) (modifyResponse $ mapResponseHeaders setDryRunHeader)
      ++ [ modifyResponse (mapResponseHeaders setDefaultHeaders)
         , fixTrailingSlash
         ]

  servantYesodMiddleware _ _ = return id

  servantServer proxy _ = servantServer' proxy

-- | Add UniWorX's default response headers.
--
-- A header is only added when the response does not already set a header of
-- the same (case-insensitive) name.
setDefaultHeaders :: ResponseHeaders -> ResponseHeaders
setDefaultHeaders = addMissingHeaders
  [ ("X-Frame-Options", "sameorigin")
  , ("X-Content-Type-Options", "nosniff")
  , ("Vary", "Accept")
  , ("X-XSS-Protection", "1; mode=block")
  ]

-- | Mark the response as a dry run, unless it already carries the dry-run
-- header.
setDryRunHeader :: ResponseHeaders -> ResponseHeaders
setDryRunHeader = addMissingHeaders
  [ (CI.mk . encodeUtf8 $ toPathPiece HeaderDryRun, encodeUtf8 $ toPathPiece True) ]

-- | @addMissingHeaders additional existing@ keeps every header of @existing@
-- unchanged and in order. It then appends those headers of @additional@ whose
-- names do not occur in @existing@.
--
-- Unlike a map-based merge, this preserves repeated header fields such as
-- multiple @Set-Cookie@ entries.
addMissingHeaders :: ResponseHeaders -> ResponseHeaders -> ResponseHeaders
addMissingHeaders additional existing
  = existing ++ filter (\(name, _) -> not $ name `HashMap.member` present) additional
  where
    present = HashMap.fromList existing

-- | Work around servant-server discarding trailing slashes.
--
-- @servant-server@ contains a special case in its implementation of
-- @runRouter@ that discards trailing slashes. Because all slashes matter to
-- us, this middleware duplicates a trailing slash. Concretely, when the last
-- path segment is empty, another empty segment is appended before the request
-- reaches the wrapped application.
fixTrailingSlash :: Middleware
fixTrailingSlash = (. fixTrailingSlash')
  where
    fixTrailingSlash' req
      | Just pathInfo' <- fromNullable $ W.pathInfo req
      , Text.null $ last pathInfo'
      = req { W.pathInfo = W.pathInfo req ++ [Text.empty] }
      | otherwise
      = req

instance ServantPersist UniWorX where
  runDB :: HasCallStack => ServantDBFor UniWorX a -> ServantHandlerFor UniWorX a
  runDB = runDB' callStack

-- | Run a database action on the foundation's connection pool via
-- 'runSqlPoolRetry''.
--
-- The call stack is passed on as a label. When the handler context flags a
-- dry run, 'transactionUndo' is executed after the action, so its effects are
-- rolled back while its result is still returned.
runDB' :: CallStack -> ServantDBFor UniWorX a -> ServantHandlerFor UniWorX a
runDB' lbl action = do
  $logDebugS "ServantPersist" "runDB"
  MkIsDryRun dryRun <- getsServantContext usctxIsDryRun
  let action'
        | dryRun    = action <* transactionUndo
        | otherwise = action
  flip (runSqlPoolRetry' action') lbl . appConnPool =<< getSite