-- SPDX-FileCopyrightText: 2022 Sarah Vaupel
--
-- SPDX-License-Identifier: AGPL-3.0-or-later

{-# OPTIONS_GHC -fno-warn-orphans #-}
{-# LANGUAGE UndecidableInstances, InstanceSigs #-}

-- | Servant integration for 'UniWorX'.
--
-- Provides 'HasServer' instances for the bearer-token and CryptoID capture
-- combinators, the servant context and handler types used by UniWorX APIs,
-- the 'ServantApiDispatch' instance (context, hoisting, WAI middleware), and
-- the 'ServantPersist' instance used to run database actions.
module Foundation.Servant
  ( ServantApiDispatchUniWorX(..)
  , UniWorXContext
  , ServantHandler, ServantDB
  ) where

import Import.Servant.NoFoundation

import Foundation.DB (runSqlPoolRetry')
import Foundation.Authorization (maybeBearerToken, IsDryRun(..), isDryRun)
import Foundation.Instances ()

import qualified Data.HashMap.Strict.InsOrd as HashMap

import Network.Wai (Middleware, modifyResponse, mapResponseHeaders)
import qualified Network.Wai as W

import Servant.Server.Internal.DelayedIO (DelayedIO, delayedFail, delayedFailFatal)

import qualified Yesod.Servant as Servant

import qualified Data.Text as Text

import Control.Monad.Catch.Pure

import Servant.Server.Internal.Delayed
import Servant.Server.Internal.Router

import Database.Persist.Sql (transactionUndo)

import qualified Data.CaseInsensitive as CI


-- | Hands the handler the restriction stored in the bearer token for the
-- current route (looked up via '_bearerRestrictionIx').
--
-- Failures are fatal ('delayedFailFatal'):
--
--   * 500 if no current route is available in the context;
--   * when the argument is required: 400 if no bearer token was provided, or
--     if the token has no restriction entry for the current route.
--
-- When the argument is optional, a missing token or missing restriction
-- yields 'Nothing'.
instance ( HasServer sub context
         , ToJSON restr, FromJSON restr
         , SBoolI (FoldRequired mods)
         , HasContextEntry context (Maybe (BearerToken UniWorX))
         , HasContextEntry context (Maybe (Route UniWorX))
         ) => HasServer (CaptureBearerRestriction' mods restr :> sub) context where
  type ServerT (CaptureBearerRestriction' mods restr :> sub) m = RequiredArgument mods restr -> ServerT sub m

  hoistServerWithContext _ pc nt s = hoistServerWithContext (Proxy @sub) pc nt . s

  route _ context subserver = route (Proxy @sub) context (subserver `addAuthCheck` bearerCheck)
    where
      bearerCheck :: DelayedIO (RequiredArgument mods restr)
      bearerCheck = do
        let bearer :: Maybe (BearerToken UniWorX)
            bearer = getContextEntry context
            cRoute :: Maybe (Route UniWorX)
            cRoute = getContextEntry context

            noRouteStored, noTokenProvided, noRestrictionProvided :: ServerError
            noTokenProvided = err400 { errBody = "The behaviour of this route depends on restrictions stored in the bearer token used for authorization. Therefor providing a bearer token is required." }
            noRestrictionProvided = err400 { errBody = "The behaviour of this route depends on restrictions stored in the bearer token used for authorization. Therefor the provided bearer token must contain a restriction entry for this route." }
            -- NOTE(review): in this module the route context entry is filled
            -- by 'servantContext' (getCurrentRoute) and servantYesodMiddleware
            -- is `return id`; confirm this message is still accurate.
            noRouteStored = err500 { errBody = "servantYesodMiddleware did not store current route in WAI vault." }
        exceptT delayedFailFatal return $ do
          cRoute' <- maybeExceptT' noRouteStored cRoute
          let -- Outer 'Maybe': was a token provided?
              -- Inner 'Maybe': does it hold a restriction for this route?
              mbRet :: Maybe (Maybe restr)
              mbRet = bearer <&> preview (_bearerRestrictionIx cRoute')
          case sbool @(FoldRequired mods) of
            SFalse -> return $ join mbRet
            STrue  -> maybe (throwE noTokenProvided) (maybe (throwE noRestrictionProvided) return) mbRet

-- | Hands the handler the bearer token from the servant context.
--
-- When the argument is required and no token was provided, the request fails
-- fatally with 400. When it is optional, the handler receives the 'Maybe'.
instance ( HasServer sub context
         , SBoolI (FoldRequired mods)
         , HasContextEntry context (Maybe (BearerToken UniWorX))
         ) => HasServer (CaptureBearerToken' mods :> sub) context where
  type ServerT (CaptureBearerToken' mods :> sub) m = RequiredArgument mods (BearerToken UniWorX) -> ServerT sub m

  hoistServerWithContext _ pc nt s = hoistServerWithContext (Proxy @sub) pc nt . s

  route _ context subserver = route (Proxy @sub) context (subserver `addAuthCheck` bearerCheck)
    where
      bearerCheck :: DelayedIO (RequiredArgument mods (BearerToken UniWorX))
      bearerCheck = do
        let bearer :: Maybe (BearerToken UniWorX)
            bearer = getContextEntry context

            noTokenProvided :: ServerError
            noTokenProvided = err400 { errBody = "The behaviour of this route depends on restrictions stored in the bearer token used for authorization. Therefor providing a bearer token is required." }
        exceptT delayedFailFatal return $
          case sbool @(FoldRequired mods) of
            SFalse -> return bearer
            STrue  -> maybe (throwE noTokenProvided) return bearer

-- | Captures a path piece as a CryptoID and decrypts it with the
-- 'appCryptoIDKey' of the 'UniWorX' value in the context.
--
-- Strict mode: a path piece that does not parse, or does not decrypt, fails
-- (non-fatally, via 'delayedFail') with 400. Lenient mode: the handler
-- receives an 'Either' whose 'Left' carries the parse or decryption error.
instance ( HasServer sub context
         , HasCryptoID ciphertext plaintext (ReaderT CryptoIDKey Catch)
         , SBoolI (FoldLenient mods)
         , FromHttpApiData ciphertext
         , HasContextEntry context UniWorX
         ) => HasServer (CaptureCryptoID' mods ciphertext sym plaintext :> sub) context where
  type ServerT (CaptureCryptoID' mods ciphertext sym plaintext :> sub) m = If (FoldLenient mods) (Either String plaintext) plaintext -> ServerT sub m

  hoistServerWithContext _ pc nt s = hoistServerWithContext (Proxy @sub) pc nt . s

  route _ context subserver = CaptureRouter . route (Proxy @sub) context . addCapture subserver $ \txt ->
      case ( sbool :: SBool (FoldLenient mods)
           , decrypt' <$> parseUrlPiece txt
           ) of
        (SFalse, Left e)            -> delayedFail err400{ errBody = fromStrict $ encodeUtf8 e }
        (SFalse, Right (Left _))    -> delayedFail err400{ errBody = "Could not decrypt CryptoID" }
        (SFalse, Right (Right pID)) -> return pID
        -- Lenient: merge parse and decryption errors into a single 'Left'.
        (STrue, join -> piece)      -> return $ left unpack piece
    where
      -- Decryption runs in a pure 'Catch'; exceptions are rendered via 'tshow'.
      decrypt' :: CryptoID ciphertext plaintext -> Either Text plaintext
      decrypt' inp = left tshow . runCatch . runReaderT (decrypt inp) . appCryptoIDKey $ getContextEntry context


-- | Servant context supplied to UniWorX APIs. It holds the current route, the
-- bearer token (if any), the dry-run flag and the foundation value, in the
-- order built by 'servantContext'.
type UniWorXContext = Maybe (Route UniWorX) ': Maybe (BearerToken UniWorX) ': IsDryRun ': UniWorX ': '[]

-- | Handler monad for UniWorX servant APIs.
type ServantHandler = ServantHandlerFor UniWorX

-- | Database monad for UniWorX servant APIs.
type ServantDB = ServantDBFor UniWorX

-- Logging for servant handlers reuses the Yesod logging setup.
deriving via (ServantLogYesod UniWorX) instance ServantLog UniWorX

-- | Per-request state available to servant handlers.
instance HasServantHandlerContext UniWorX where
  data ServantHandlerContextFor UniWorX = ServantHandlerContextForUniWorX
    { usctxSite     :: UniWorX
    , usctxRequest  :: W.Request
    , usctxIsDryRun :: IsDryRun
    }

  getSCtxSite = usctxSite
  getSCtxRequest = usctxRequest

-- | Servant APIs that can be dispatched from UniWorX.
class ( HasServer (ServantApiUnproxy' proxy) UniWorXContext
      , Servant.HasRoute (ServantApiUnproxy' proxy)
      ) => ServantApiDispatchUniWorX proxy where
  -- | The server implementation for the API.
  servantServer' :: ServantApi proxy -> ServerT (ServantApiUnproxy' proxy) ServantHandler

-- | Dispatches any 'ServantApiDispatchUniWorX' API.
instance ServantApiDispatchUniWorX proxy => ServantApiDispatch UniWorXContext ServantHandler UniWorX proxy where
  -- Collect the per-request context entries; the order must match 'UniWorXContext'.
  servantContext _ app _ = do
    isDryRun' <- MkIsDryRun <$> isDryRun
    restr <- maybeBearerToken
    cRoute <- getCurrentRoute
    return $ cRoute :. restr :. isDryRun' :. app :. EmptyContext

  -- Run a handler with the site, the WAI request and the dry-run flag from the context.
  servantHoist _ usctxSite usctxRequest ctx = ($ ServantHandlerContextForUniWorX{ usctxIsDryRun = getContextEntry ctx, .. }) . unServantHandlerFor

  -- Compose the middlewares (left-most is outermost). The dry-run header is
  -- only added for dry runs; default headers and the trailing-slash fix
  -- always apply.
  servantMiddleware _ _ ctx = appEndo . foldMap Endo $
       guardOn (unIsDryRun $ getContextEntry ctx) (modifyResponse $ mapResponseHeaders setDryRunHeader)
    ++ [ modifyResponse (mapResponseHeaders setDefaultHeaders)
       , fixTrailingSlash
       ]

  servantYesodMiddleware _ _ = return id

  servantServer proxy _ = servantServer' proxy


-- | @addMissingHeaders new existing@ returns @existing@ unchanged, followed by
-- every header of @new@ whose name does not occur in @existing@.
--
-- Existing headers win. Repeated existing headers (e.g. several @Set-Cookie@
-- headers) are preserved as-is and in order. They are not collapsed, as a
-- round trip through a map would do.
addMissingHeaders :: ResponseHeaders -> ResponseHeaders -> ResponseHeaders
addMissingHeaders new existing = existing ++ [ h | h@(hName, _) <- new, not $ HashMap.member hName existingByName ]
  where
    existingByName = HashMap.fromList existing

-- | Add default security-related response headers the response does not set itself.
setDefaultHeaders :: ResponseHeaders -> ResponseHeaders
setDefaultHeaders = addMissingHeaders
  [ ("X-Frame-Options", "sameorigin")
  , ("X-Content-Type-Options", "nosniff")
  , ("Vary", "Accept")
  , ("X-XSS-Protection", "1; mode=block")
  ]

-- | Mark the response as produced by a dry run, unless the header is already set.
setDryRunHeader :: ResponseHeaders -> ResponseHeaders
setDryRunHeader = addMissingHeaders
  [ (CI.mk . encodeUtf8 $ toPathPiece HeaderDryRun, encodeUtf8 $ toPathPiece True) ]

-- | Work around @servant-server@ discarding trailing slashes.
--
-- @servant-server@ contains a special case in its implementation of
-- @runRouter@ that discards trailing slashes. Because all slashes matter,
-- this duplicates trailing slashes. A request path whose last segment is
-- empty gets one more empty segment appended before routing.
fixTrailingSlash :: Middleware
fixTrailingSlash = (. fixTrailingSlash')
  where
    fixTrailingSlash' req
      | Just pathInfo' <- fromNullable $ W.pathInfo req
      , Text.null $ last pathInfo'
      = req { W.pathInfo = W.pathInfo req ++ [Text.empty] }
      | otherwise
      = req

-- | Database access for servant handlers; see 'runDB''.
instance ServantPersist UniWorX where
  runDB :: HasCallStack => ServantDBFor UniWorX a -> ServantHandlerFor UniWorX a
  runDB = runDB' callStack

-- | Run a database action on the site's connection pool via 'runSqlPoolRetry''.
-- The call stack is passed along as a label.
--
-- For dry runs the action is followed by 'transactionUndo', so its writes
-- are rolled back while its result is still returned.
runDB' :: CallStack -> ServantDBFor UniWorX a -> ServantHandlerFor UniWorX a
runDB' lbl action = do
  $logDebugS "ServantPersist" "runDB"
  MkIsDryRun dryRun <- getsServantContext usctxIsDryRun
  let action'
        | dryRun    = action <* transactionUndo
        | otherwise = action
  flip (runSqlPoolRetry' action') lbl . appConnPool =<< getSite