feat(servant): dry-run support

This commit is contained in:
Gregor Kleen 2022-01-26 22:09:03 +01:00
parent 605b7758e6
commit 47df8a312f
16 changed files with 585 additions and 170 deletions

View File

@ -346,6 +346,7 @@ tests:
- quickcheck-io
- network-arbitrary
- lens-properties
- http-media
ghc-options:
- -fno-warn-orphans
- -threaded -rtsopts "-with-rtsopts=-N -T"

View File

@ -9,6 +9,7 @@ module Foundation.Authorization
, wouldHaveReadAccessToIff, wouldHaveWriteAccessToIff
, AuthContext(..), getAuthContext
, isDryRun, isDryRunDB
, IsDryRun(..)
, maybeBearerToken, requireBearerToken
, requireCurrentBearerRestrictions, maybeCurrentBearerRestrictions
, BearerAuthSite, MonadAP
@ -276,7 +277,9 @@ getAuthContext = liftHandler $ do
return authCtx
newtype IsDryRun = MkIsDryRun { unIsDryRun :: Bool }
deriving (Eq, Ord, Read, Show, Generic, Typeable)
deriving stock (Read, Show, Generic, Typeable)
deriving newtype (Eq, Ord)
deriving (Semigroup, Monoid) via Any
isDryRun :: ( HasCallStack
, BearerAuthSite UniWorX
@ -296,6 +299,7 @@ isDryRunDB :: forall m backend m'.
isDryRunDB = fmap unIsDryRun . cached . fmap MkIsDryRun $ orM
[ hasGlobalPostParam PostDryRun
, hasGlobalGetParam GetDryRun
, hasCustomHeader HeaderDryRun
, and2M bearerDryRun bearerRequired
]
where

View File

@ -9,19 +9,15 @@ module Foundation.Servant
import Import.Servant.NoFoundation
import Foundation.DB (runSqlPoolRetry')
import Foundation.Authorization (maybeBearerToken)
import Foundation.Authorization (maybeBearerToken, IsDryRun(..), isDryRun)
import Foundation.Instances ()
import qualified Data.HashMap.Strict.InsOrd as HashMap
import Network.Wai (Middleware, modifyResponse, mapResponseHeaders, vault)
import Network.Wai (Middleware, modifyResponse, mapResponseHeaders)
import qualified Network.Wai as W
import qualified Data.Vault.Lazy as Vault
import Servant.Server.Internal.DelayedIO (DelayedIO, delayedFail, delayedFailFatal, withRequest)
import System.IO.Unsafe (unsafePerformIO)
import Servant.Server.Internal.DelayedIO (DelayedIO, delayedFail, delayedFailFatal)
import qualified Yesod.Servant as Servant
@ -32,21 +28,16 @@ import Control.Monad.Catch.Pure
import Servant.Server.Internal.Delayed
import Servant.Server.Internal.Router
-- import Database.Persist.Sql (transactionUndo)
import Database.Persist.Sql (transactionUndo)
waiBearerKey :: Vault.Key (Maybe (BearerToken UniWorX))
waiBearerKey = unsafePerformIO Vault.newKey
{-# NOINLINE waiBearerKey #-}
waiRouteKey :: Vault.Key (Route UniWorX)
waiRouteKey = unsafePerformIO Vault.newKey
{-# NOINLINE waiRouteKey #-}
import qualified Data.CaseInsensitive as CI
instance ( HasServer sub context
, ToJSON restr, FromJSON restr
, SBoolI (FoldRequired mods)
, HasContextEntry context (Maybe (BearerToken UniWorX))
, HasContextEntry context (Maybe (Route UniWorX))
)
=> HasServer (CaptureBearerRestriction' mods restr :> sub) context
where
@ -56,25 +47,25 @@ instance ( HasServer sub context
hoistServerWithContext _ pc nt s = hoistServerWithContext (Proxy @sub) pc nt . s
route _ context subserver
= route (Proxy @sub) context (subserver `addAuthCheck` withRequest bearerCheck)
= route (Proxy @sub) context (subserver `addAuthCheck` bearerCheck)
where
bearerCheck :: W.Request -> DelayedIO (RequiredArgument mods restr)
bearerCheck req = do
let bearer = Vault.lookup waiBearerKey $ vault req
cRoute = Vault.lookup waiRouteKey $ vault req
bearerCheck :: DelayedIO (RequiredArgument mods restr)
bearerCheck = do
let bearer :: Maybe (BearerToken UniWorX)
bearer = getContextEntry context
cRoute :: Maybe (Route UniWorX)
cRoute = getContextEntry context
noRouteStored, noTokenStored, noTokenProvided, noRestrictionProvided :: ServerError
noTokenStored = err500 { errBody = "servantYesodMiddleware did not store bearer token in WAI vault." }
noRouteStored, noTokenProvided, noRestrictionProvided :: ServerError
noTokenProvided = err400 { errBody = "The behaviour of this route depends on restrictions stored in the bearer token used for authorization. Therefor providing a bearer token is required." }
noRestrictionProvided = err400 { errBody = "The behaviour of this route depends on restrictions stored in the bearer token used for authorization. Therefor the provided bearer token must contain a restriction entry for this route." }
noRouteStored = err500 { errBody = "servantYesodMiddleware did not store current route in WAI vault." }
exceptT delayedFailFatal return $ do
bearer' <- maybeExceptT' noTokenStored bearer
cRoute' <- maybeExceptT' noRouteStored cRoute
let mbRet :: Maybe (Maybe restr)
mbRet = bearer' <&> preview (_bearerRestrictionIx cRoute')
mbRet = bearer <&> preview (_bearerRestrictionIx cRoute')
case sbool @(FoldRequired mods) of
SFalse -> return $ join mbRet
STrue -> maybe (throwE noTokenProvided) (maybe (throwE noRestrictionProvided) return) mbRet
@ -82,6 +73,7 @@ instance ( HasServer sub context
instance ( HasServer sub context
, SBoolI (FoldRequired mods)
, HasContextEntry context (Maybe (BearerToken UniWorX))
)
=> HasServer (CaptureBearerToken' mods :> sub) context
where
@ -91,21 +83,20 @@ instance ( HasServer sub context
hoistServerWithContext _ pc nt s = hoistServerWithContext (Proxy @sub) pc nt . s
route _ context subserver
= route (Proxy @sub) context (subserver `addAuthCheck` withRequest bearerCheck)
= route (Proxy @sub) context (subserver `addAuthCheck` bearerCheck)
where
bearerCheck :: W.Request -> DelayedIO (RequiredArgument mods (BearerToken UniWorX))
bearerCheck req = do
let bearer = Vault.lookup waiBearerKey $ vault req
bearerCheck :: DelayedIO (RequiredArgument mods (BearerToken UniWorX))
bearerCheck = do
let bearer :: Maybe (BearerToken UniWorX)
bearer = getContextEntry context
noTokenStored, noTokenProvided :: ServerError
noTokenStored = err500 { errBody = "servantYesodMiddleware did not store bearer token in WAI vault." }
noTokenProvided :: ServerError
noTokenProvided = err400 { errBody = "The behaviour of this route depends on restrictions stored in the bearer token used for authorization. Therefor providing a bearer token is required." }
exceptT delayedFailFatal return $ do
bearer' <- maybeExceptT' noTokenStored bearer
case sbool @(FoldRequired mods) of
SFalse -> return bearer'
STrue -> maybe (throwE noTokenProvided) return bearer'
SFalse -> return bearer
STrue -> maybe (throwE noTokenProvided) return bearer
instance ( HasServer sub context
@ -132,23 +123,40 @@ instance ( HasServer sub context
decrypt' inp = left tshow . runCatch . runReaderT (decrypt inp) . appCryptoIDKey $ getContextEntry context
type UniWorXContext = UniWorX ': '[]
type UniWorXContext = Maybe (Route UniWorX) ': Maybe (BearerToken UniWorX) ': IsDryRun ': UniWorX ': '[]
type ServantHandler = ServantHandlerFor UniWorX
type ServantDB = ServantDBFor UniWorX
deriving via (ServantLogYesod UniWorX) instance ServantLog UniWorX
instance HasServantHandlerContext UniWorX where
data ServantHandlerContextFor UniWorX = ServantHandlerContextForUniWorX
{ usctxSite :: UniWorX
, usctxRequest :: W.Request
, usctxIsDryRun :: IsDryRun
}
getSCtxSite = usctxSite
getSCtxRequest = usctxRequest
class (HasServer (ServantApiUnproxy' proxy) UniWorXContext, Servant.HasRoute (ServantApiUnproxy' proxy)) => ServantApiDispatchUniWorX proxy where
servantServer' :: ServantApi proxy -> ServerT (ServantApiUnproxy' proxy) ServantHandler
instance ServantApiDispatchUniWorX proxy => ServantApiDispatch UniWorXContext ServantHandler UniWorX proxy where
servantContext _ app _ = return $ app :. EmptyContext
servantHoist _ sctxSite sctxRequest _ = ($ ServantHandlerContextFor{..}) . unServantHandlerFor
servantMiddleware _ _ _ = modifyResponse (mapResponseHeaders setDefaultHeaders) . fixTrailingSlash
servantYesodMiddleware _ _ = appEndo <$> foldMapM (fmap Endo) [storeBearerToken, storeCurrentRoute]
servantContext _ app _ = do
isDryRun' <- MkIsDryRun <$> isDryRun
restr <- maybeBearerToken
cRoute <- getCurrentRoute
return $ cRoute :. restr :. isDryRun' :. app :. EmptyContext
servantHoist _ usctxSite usctxRequest ctx = ($ ServantHandlerContextForUniWorX{ usctxIsDryRun = getContextEntry ctx, .. }) . unServantHandlerFor
servantMiddleware _ _ ctx = appEndo . foldMap Endo $
guardOn (unIsDryRun $ getContextEntry ctx) (modifyResponse $ mapResponseHeaders setDryRunHeader)
++ [ modifyResponse (mapResponseHeaders setDefaultHeaders)
, fixTrailingSlash
]
servantYesodMiddleware _ _ = return id
servantServer proxy _ = servantServer' proxy
setDefaultHeaders :: ResponseHeaders -> ResponseHeaders
setDefaultHeaders, setDryRunHeader :: ResponseHeaders -> ResponseHeaders
setDefaultHeaders existing = HashMap.toList $ HashMap.fromList existing <> defaultHeaders
where defaultHeaders = HashMap.fromList
[ ("X-Frame-Options", "sameorigin")
@ -156,6 +164,7 @@ setDefaultHeaders existing = HashMap.toList $ HashMap.fromList existing <> defau
, ("Vary", "Accept")
, ("X-XSS-Protection", "1; mode=block")
]
setDryRunHeader existing = HashMap.toList $ HashMap.fromList existing <> HashMap.singleton (CI.mk . encodeUtf8 $ toPathPiece HeaderDryRun) (encodeUtf8 $ toPathPiece True)
fixTrailingSlash :: Middleware
-- ^ `servant-server` contains a special case in their implementation
@ -170,17 +179,6 @@ fixTrailingSlash = (. fixTrailingSlash')
| otherwise
= req
storeBearerToken, storeCurrentRoute :: HandlerFor UniWorX Middleware
storeBearerToken = do
restr <- maybeBearerToken
return $ \app req -> app req{ vault = Vault.insert waiBearerKey restr $ vault req }
storeCurrentRoute = do
cRoute <- getCurrentRoute
$logDebugS "storeCurrentRoute" $ tshow cRoute
return $ \app req -> app req{ vault = maybe id (Vault.insert waiRouteKey) cRoute $ vault req }
instance ServantPersist UniWorX where
runDB :: HasCallStack => ServantDBFor UniWorX a -> ServantHandlerFor UniWorX a
@ -189,9 +187,9 @@ instance ServantPersist UniWorX where
runDB' :: CallStack -> ServantDBFor UniWorX a -> ServantHandlerFor UniWorX a
runDB' lbl action = do
$logDebugS "ServantPersist" "runDB"
-- let action' = do
-- dryRun <- isDryRunDB
-- if | dryRun -> action <* transactionUndo
-- | otherwise -> action
MkIsDryRun dryRun <- getsServantContext usctxIsDryRun
let action'
| dryRun = action <* transactionUndo
| otherwise = action
flip (runSqlPoolRetry' action) lbl . appConnPool =<< getSite
flip (runSqlPoolRetry' action') lbl . appConnPool =<< getSite

View File

@ -5,12 +5,14 @@ module Foundation.Servant.Types
, CaptureBearerToken, CaptureBearerToken'
, CaptureCryptoID', CaptureCryptoID, CaptureCryptoUUID, CaptureCryptoFileName
, ApiVersion, apiVersionToSemVer, matchesApiVersion
, BearerAuth, SessionAuth
) where
import ClassyPrelude
import ClassyPrelude hiding (fromList)
import Data.Proxy
import Servant.API
import Servant.API.Modifiers (FoldRequired)
import Servant.API.Description
import Servant.Swagger
import Servant.Docs
@ -21,6 +23,13 @@ import Servant.Server.Internal.Delayed
import Servant.Server.Internal.ErrorFormatter
-- import Servant.Server.Internal.DelayedIO
import Servant.Client.Core.RunClient (RunClient)
import Servant.Client.Core.HasClient
import qualified Servant.Client.Core.Request as Servant (Request)
import qualified Servant.Client.Core.Request as Request
import Jose.Jwt (Jwt(..))
import Network.Wai (mapResponseHeaders, requestHeaders)
import Control.Lens hiding (Context)
@ -31,8 +40,9 @@ import Data.CryptoID.Class.ImplicitNamespace
import Data.CryptoID.Instances ()
import GHC.TypeLits
import GHC.Exts (IsList(..))
import Data.Swagger (ToParamSchema)
import Data.Swagger hiding (version)
import Data.Kind (Type)
@ -114,6 +124,11 @@ instance HasDocs sub => HasDocs (CaptureBearerToken' mods :> sub) where
instance (ToCapture (Capture sym ciphertext), KnownSymbol sym, HasDocs sub) => HasDocs (CaptureCryptoID' mods ciphertext sym plaintext :> sub) where
docsFor _ = docsFor $ Proxy @(Capture' mods sym ciphertext :> sub)
instance (RunClient m, HasClient m (Capture' mods sym (CryptoID ciphertext plaintext) :> sub)) => HasClient m (CaptureCryptoID' mods ciphertext sym plaintext :> sub) where
type Client m (CaptureCryptoID' mods ciphertext sym plaintext :> sub) = Client m (Capture' mods sym (CryptoID ciphertext plaintext) :> sub)
clientWithRoute pm _ = clientWithRoute pm $ Proxy @(Capture' mods sym (CryptoID ciphertext plaintext) :> sub)
hoistClientMonad pm _ = hoistClientMonad pm $ Proxy @(Capture' mods sym (CryptoID ciphertext plaintext) :> sub)
type family ApiVersionSub major minor patch sup sub where
ApiVersionSub major minor patch (ApiVersion major' minor' patch') sub = ApiVersion major' minor' patch' :> sub
@ -143,8 +158,30 @@ instance ( HasServer (ApiVersion major minor patch :> a) context
choice' = case (sbool :: SBool (IsLT (CmpVersion (FinalApiVersion (ApiVersion major minor patch :> a)) (FinalApiVersion (ApiVersion major minor patch :> b))))) of
STrue -> flip choice
SFalse -> choice
instance (RunClient m, HasClient m (ApiVersionSub major minor patch sup sub)) => HasClient m (ApiVersion major minor patch :> ((sup :: Type) :> sub)) where
type Client m (ApiVersion major minor patch :> (sup :> sub)) = Client m (ApiVersionSub major minor patch sup sub)
clientWithRoute pm _ = clientWithRoute pm $ Proxy @(ApiVersionSub major minor patch sup sub)
hoistClientMonad pm _ = hoistClientMonad pm $ Proxy @(ApiVersionSub major minor patch sup sub)
instance (RunClient m, HasClient m (sup :> (ApiVersion major minor patch :> sub))) => HasClient m (ApiVersion major minor patch :> ((sup :: Symbol) :> sub)) where
type Client m (ApiVersion major minor patch :> (sup :> sub)) = Client m (sup :> (ApiVersion major minor patch :> sub))
clientWithRoute pm _ = clientWithRoute pm $ Proxy @(sup :> (ApiVersion major minor patch :> sub))
hoistClientMonad pm _ = hoistClientMonad pm $ Proxy @(sup :> (ApiVersion major minor patch :> sub))
instance ( HasClient m (ApiVersion major minor patch :> a)
, HasClient m (ApiVersion major minor patch :> b)
) => HasClient m (ApiVersion major minor patch :> (a :<|> b)) where
type Client m (ApiVersion major minor patch :> (a :<|> b)) = Client m (ApiVersion major minor patch :> a) :<|> Client m (ApiVersion major minor patch :> b)
clientWithRoute pm _ req = clientWithRoute pm (Proxy @(ApiVersion major minor patch :> a)) req
:<|> clientWithRoute pm (Proxy @(ApiVersion major minor patch :> b)) req
hoistClientMonad pm _ f (ca :<|> cb) = hoistClientMonad pm (Proxy @(ApiVersion major minor patch :> a)) f ca
:<|> hoistClientMonad pm (Proxy @(ApiVersion major minor patch :> b)) f cb
versionRequestHeaderName :: CI ByteString
versionRequestHeaderName = "Accept-API-Version"
routeWithApiVersion :: forall api context env major minor patch.
( HasServer api context
, KnownNat major, KnownNat minor, KnownNat patch
@ -168,7 +205,6 @@ routeWithApiVersion _ _ context subserver = RawRouter $ \env req ((. addVersion)
version = apiVersionToSemVer $ Proxy @(ApiVersion major minor patch)
versionHeaderName = "API-Version"
versionRequestHeaderName = "Accept-API-Version"
versionHeader = encodeUtf8 $ SemVer.toText version
notFound = notFoundErrorFormatter . getContextEntry $ mkContextWithErrorFormatter context
@ -194,6 +230,26 @@ instance ( HasServer (NoContentVerb method) context
route _ = routeWithApiVersion (Proxy @(ApiVersion major minor patch)) (Proxy @(NoContentVerb method))
semVerCompatibleTo :: SemVer.Version -> SemVer.Constraint
semVerCompatibleTo v = SemVer.Constraint.CAnd (SemVer.Constraint.CGtEq v) (SemVer.Constraint.CLt $ SemVer.incrementMajor v)
instance ( HasClient m (Verb method statusCode contentTypes a)
, KnownNat major, KnownNat minor, KnownNat patch
) => HasClient m (ApiVersion major minor patch :> Verb method statusCode contentTypes a) where
type Client m (ApiVersion major minor patch :> Verb method statusCode contentTypes a) = Client m (Verb method statusCode contentTypes a)
clientWithRoute pm _ = clientWithRoute pm (Proxy @(Verb method statusCode contentTypes a)) . Request.addHeader versionRequestHeaderName (semVerCompatibleTo version)
where version = apiVersionToSemVer $ Proxy @(ApiVersion major minor patch)
hoistClientMonad pm _ = hoistClientMonad pm $ Proxy @(Verb method statusCode contentTypes a)
instance ( HasClient m (NoContentVerb method)
, KnownNat major, KnownNat minor, KnownNat patch
) => HasClient m (ApiVersion major minor patch :> NoContentVerb method) where
type Client m (ApiVersion major minor patch :> NoContentVerb method) = Client m (NoContentVerb method)
clientWithRoute pm _ = clientWithRoute pm (Proxy @(NoContentVerb method)) . Request.addHeader versionRequestHeaderName (semVerCompatibleTo version)
where version = apiVersionToSemVer $ Proxy @(ApiVersion major minor patch)
hoistClientMonad pm _ = hoistClientMonad pm $ Proxy @(NoContentVerb method)
instance ( HasDocs (ApiVersionSub major minor patch sup sub)
) => HasDocs (ApiVersion major minor patch :> ((sup :: Type) :> sub)) where
docsFor _ = docsFor $ Proxy @(ApiVersionSub major minor patch sup sub)
@ -263,3 +319,105 @@ type family IsLT x where
type instance IsElem' sa (CaptureCryptoID' mods ciphertext sym plaintext :> sb) = IsElem sa (Capture' mods sym (CryptoID ciphertext plaintext) :> sb)
type instance IsElem' sa (ApiVersion major minor patch :> sb) = IsElem sa sb
type family StripBearer api where
StripBearer (CaptureBearerRestriction' mods restr :> sub) = sub
StripBearer (CaptureBearerToken' mods :> sub) = sub
StripBearer (BearerAuth :> sub) = sub
StripBearer (sup :> sub) = sup :> StripBearer sub
StripBearer (a :<|> b) = StripBearer a :<|> StripBearer b
StripBearer (Verb method statusCode contentTypes a) = Verb method statusCode contentTypes a
StripBearer (NoContentVerb method) = NoContentVerb method
type family BearerRequired api where
BearerRequired (CaptureBearerRestriction' mods restr :> sub) = OrBool (FoldRequired mods) (BearerRequired sub)
BearerRequired (CaptureBearerToken' mods :> sub) = OrBool (FoldRequired mods) (BearerRequired sub)
BearerRequired (BearerAuth :> sub) = 'True
BearerRequired (sup :> sub) = BearerRequired sub
BearerRequired (a :<|> b) = OrBool (BearerRequired a) (BearerRequired b)
BearerRequired (Verb method statusCode contentTypes a) = 'False
BearerRequired (NoContentVerb method) = 'False
type family OrBool a b where
OrBool 'False 'False = 'False
OrBool a b = 'True
maybeWithJwt :: forall (a :: Bool). SBoolI a => Proxy a -> If a Jwt (Maybe Jwt) -> Servant.Request -> Servant.Request
maybeWithJwt _ mparam = case (sbool :: SBool a, mparam) of
(STrue, jwt) -> add jwt
(SFalse, mJwt) -> maybe id add mJwt
where add (Jwt jwt) = Request.addHeader "Authorization" . decodeUtf8 $ "Bearer " <> jwt
instance ( HasClient m (StripBearer sub)
, RunClient m
, SBoolI (BearerRequired (CaptureBearerRestriction' mods restr :> sub))
) => HasClient m (CaptureBearerRestriction' mods restr :> sub) where
type Client m (CaptureBearerRestriction' mods restr :> sub) = If (BearerRequired (CaptureBearerRestriction' mods restr :> sub)) Jwt (Maybe Jwt) -> Client m (StripBearer sub)
clientWithRoute pm _ req mparam = clientWithRoute pm (Proxy @(StripBearer sub)) $ maybeWithJwt (Proxy @(BearerRequired (CaptureBearerRestriction' mods restr :> sub))) mparam req
hoistClientMonad pm _ f cl = hoistClientMonad pm (Proxy @(StripBearer sub)) f . cl
instance ( HasClient m (StripBearer sub)
, RunClient m
, SBoolI (BearerRequired (CaptureBearerToken' mods :> sub))
) => HasClient m (CaptureBearerToken' mods :> sub) where
type Client m (CaptureBearerToken' mods :> sub) = If (BearerRequired (CaptureBearerToken' mods :> sub)) Jwt (Maybe Jwt) -> Client m (StripBearer sub)
clientWithRoute pm _ req mparam = clientWithRoute pm (Proxy @(StripBearer sub)) $ maybeWithJwt (Proxy @(BearerRequired (CaptureBearerToken' mods :> sub))) mparam req
hoistClientMonad pm _ f cl = hoistClientMonad pm (Proxy @(StripBearer sub)) f . cl
instance ( HasClient m (StripBearer sub)
, RunClient m
, SBoolI (BearerRequired (BearerAuth :> sub))
) => HasClient m (BearerAuth :> sub) where
type Client m (BearerAuth :> sub) = If (BearerRequired (BearerAuth :> sub)) Jwt (Maybe Jwt) -> Client m (StripBearer sub)
clientWithRoute pm _ req mparam = clientWithRoute pm (Proxy @(StripBearer sub)) $ maybeWithJwt (Proxy @(BearerRequired (BearerAuth :> sub))) mparam req
hoistClientMonad pm _ f cl = hoistClientMonad pm (Proxy @(StripBearer sub)) f . cl
data BearerAuth
data SessionAuth
instance HasSwagger sub => HasSwagger (BearerAuth :> sub) where
toSwagger _ = toSwagger (Proxy @sub)
& securityDefinitions <>~ SecurityDefinitions (fromList [(defnKey, defn)])
& allOperations . security <>~ [SecurityRequirement $ fromList [(defnKey, [])]]
where defnKey :: Text
defnKey = "bearer"
defn = SecurityScheme
{ _securitySchemeType
= SecuritySchemeApiKey ApiKeyParams
{ _apiKeyName = "Authorization"
, _apiKeyIn = ApiKeyHeader
}
, _securitySchemeDescription = Just
"JSON Web Token-based API key"
}
instance HasSwagger sub => HasSwagger (SessionAuth :> sub) where
toSwagger _ = toSwagger (Proxy @sub)
& allOperations . security <>~ [SecurityRequirement mempty]
-- We do not expect API clients to be able/willing to conform with
-- our CSRF mitigation, so we mark routes that require it as
-- having unfullfillable security requirements
instance HasLink sub => HasLink (BearerAuth :> sub) where
type MkLink (BearerAuth :> sub) a = MkLink sub a
toLink toA _ = toLink toA (Proxy @sub)
instance HasLink sub => HasLink (SessionAuth :> sub) where
type MkLink (SessionAuth :> sub) a = MkLink sub a
toLink toA _ = toLink toA (Proxy @sub)
instance HasDocs sub => HasDocs (BearerAuth :> sub) where
docsFor _ (endpoint, action) = docsFor (Proxy @sub) (endpoint, action')
where action' = action & authInfo %~ (|> authInfo')
authInfo' = DocAuthentication
""
"A JSON Web Token-based API key"
instance HasDocs sub => HasDocs (SessionAuth :> sub) where
docsFor _ (endpoint, action) = docsFor (Proxy @sub) (endpoint, action')
where action' = action & authInfo %~ (|> authInfo')
authInfo' = DocAuthentication
"When a web session is used for authorization, CSRF-mitigation measures must be observed."
"An active web session identifying the user as one with sufficient authorization"

View File

@ -40,7 +40,7 @@ import Yesod.Core.Types as Import (loggerSet)
import Yesod.Default.Config2 as Import
import Yesod.Core.Types.Instances as Import
import Yesod.Servant as Import
hiding ( MonadHandler(..), HasRoute(..)
hiding ( MonadHandler(..), HasRoute(..), MonadRequest(..)
, runDB, defaultRunDB
)
import Servant.Docs as Import
@ -210,6 +210,7 @@ import Data.MonoTraversable.Instances as Import ()
import Servant.Client.Core.BaseUrl.Instances as Import ()
import Control.Monad.Trans.Except.Instances as Import ()
import Servant.Server.Instances as Import ()
import Servant.Docs.Internal.Pretty.Instances as Import ()
import Network.URI.Instances as Import ()
import Data.HashSet.Instances as Import ()
import Web.Cookie.Instances as Import ()

View File

@ -14,6 +14,7 @@ import Import.NoFoundation as Import hiding
, MonadHandler(..), HasRoute(..), liftHandler
, encrypt, decrypt
, Unique, Fragment(..), respond
, getRequest
)
import Yesod.Servant as Import

View File

@ -0,0 +1,14 @@
{-# OPTIONS_GHC -fno-warn-orphans #-}
module Servant.Docs.Internal.Pretty.Instances () where
import ClassyPrelude
import Servant.Docs.Internal.Pretty
import Servant.API.ContentTypes
import Data.Proxy
instance MimeUnrender JSON a => MimeUnrender PrettyJSON a where
mimeUnrender _ = mimeUnrender $ Proxy @JSON

View File

@ -15,17 +15,22 @@ import Jose.Jwk (JwkSet(..))
{-# ANN module ("HLint: ignore Use newtype instead of data" :: String) #-}
type ExternalApisListR = Get '[PrettyJSON] ExternalApisList
type ExternalApisCreateR = CaptureBearerRestriction' '[Optional] ExternalApiCreationRestrictions
type ExternalApisListR = ApiVersion 1 0 0
:> Get '[PrettyJSON] ExternalApisList
type ExternalApisCreateR = ApiVersion 1 0 0
:> CaptureBearerRestriction' '[Optional] ExternalApiCreationRestrictions
:> CaptureBearerToken
:> ReqBody '[JSON] ExternalApiCreationRequest
:> PostCreated '[PrettyJSON] (Headers '[Header "Location" URI] ExternalApiCreationResponse)
type ExternalApisPongR = CaptureCryptoUUID "external-api" ExternalApiId
type ExternalApisPongR = ApiVersion 1 0 0
:> CaptureCryptoUUID "external-api" ExternalApiId
:> "pong"
:> Post '[PrettyJSON] ExternalApiPongResponse
type ExternalApisInfoR = CaptureCryptoUUID "external-api" ExternalApiId
type ExternalApisInfoR = ApiVersion 1 0 0
:> CaptureCryptoUUID "external-api" ExternalApiId
:> Get '[PrettyJSON] ExternalApiInfo
type ExternalApisDeleteR = CaptureCryptoUUID "external-api" ExternalApiId
type ExternalApisDeleteR = ApiVersion 1 0 0
:> CaptureCryptoUUID "external-api" ExternalApiId
:> DeleteNoContent
data ExternalApis mode = ExternalApis
@ -37,7 +42,7 @@ data ExternalApis mode = ExternalApis
} deriving (Generic)
type ServantApiExternalApis = ServantApi ExternalApis
type instance ServantApiUnproxy ExternalApis = ApiVersion 1 0 0 :> ToServantApi ExternalApis
type instance ServantApiUnproxy ExternalApis = ToServantApi ExternalApis
instance ToCapture (Capture "external-api" UUID) where
@ -122,7 +127,9 @@ data ExternalApiInfo = ExternalApiInfo
instance ToJSON ExternalApiInfo where
toJSON ExternalApiInfo{..} = object $ maybe id ((:) . ("ident" .=)) eaiIdent
[ "token-authority" .= foldMap (HashSet.singleton . either id toJSON) eaiTokenAuthority
[ "token-authority" .= case HashSet.toList eaiTokenAuthority of
[x] -> either id toJSON x
_ -> toJSON $ foldMap (HashSet.singleton . either id toJSON) eaiTokenAuthority
, "token-issued" .= eaiTokenIssued
, "token-expires-at" .= eaiTokenExpiresAt
, "token-starts-at" .= eaiTokenStartsAt
@ -134,11 +141,11 @@ instance ToJSON ExternalApiInfo where
instance FromJSON ExternalApiInfo where
parseJSON = withObject "ExternalApiInfo" $ \o -> do
eaiIdent <- o .:? "token-authority"
eaiIdent <- o .:? "ident"
eaiTokenAuthority <- asum
[ HashSet.singleton . Right <$> o .: "authority"
, (o .: "authority" :: _ (HashSet Value)) >>= foldMapM (\v' -> fmap HashSet.singleton $ (Right <$> parseJSON v') <|> return (Left v'))
, HashSet.singleton . Left <$> o .: "authority"
[ HashSet.singleton . Right <$> o .: "token-authority"
, (o .: "token-authority" :: _ (HashSet Value)) >>= foldMapM (\v' -> fmap HashSet.singleton $ (Right <$> parseJSON v') <|> return (Left v'))
, HashSet.singleton . Left <$> o .: "token-authority"
]
eaiTokenIssued <- o .: "token-issued"
eaiTokenExpiresAt <- o .: "token-expires-at"

View File

@ -113,7 +113,9 @@ import Data.Binary (Binary)
import qualified Data.Binary as Binary
import Network.Wai (requestMethod)
import Network.HTTP.Types.Header
import Network.HTTP.Types.Header as Wai
import Web.HttpApiData
import Data.Time.Clock
@ -1143,6 +1145,9 @@ addCustomHeader, replaceOrAddCustomHeader :: (MonadHandler m, PathPiece payload)
addCustomHeader ident payload = addHeader (toPathPiece ident) (toPathPiece payload)
replaceOrAddCustomHeader ident payload = replaceOrAddHeader (toPathPiece ident) (toPathPiece payload)
waiCustomHeader :: ToHttpApiData payload => CustomHeader -> payload -> Wai.Header
waiCustomHeader ident payload = (CI.mk . encodeUtf8 $ toPathPiece ident, toHeader payload)
------------------
-- HTTP Headers --
------------------

View File

@ -8,8 +8,8 @@ module Yesod.Servant
, ServantApiDispatch(..)
, servantApiLink
, ServantHandlerFor(..)
, ServantHandlerContextFor(..), getServantContext, getsServantContext, getYesodApproot, renderRouteAbsolute
, MonadServantHandler(..), MonadHandler(..), MonadSite(..)
, HasServantHandlerContext(..), getServantContext, getsServantContext, getYesodApproot, renderRouteAbsolute, servantApiBaseUrl
, MonadServantHandler(..), MonadHandler(..), MonadSite(..), MonadRequest(..)
, ServantDBFor, ServantPersist(..), defaultRunDB
, ServantLog(..), ServantLogYesod(..)
, mkYesodApi
@ -45,6 +45,8 @@ import Servant.API
import Servant.Server hiding (route)
import Servant.Server.Instances ()
import Servant.Client.Core.BaseUrl
import Data.Proxy
import Network.Wai (Request, Middleware)
@ -59,13 +61,10 @@ import Control.Monad.Fail (MonadFail(..))
import Data.Data (Data)
import Data.Kind (Type)
import GHC.Exts (IsList(..), Constraint)
import GHC.Exts (Constraint)
import Servant.Swagger
import Data.Swagger
import Servant.Docs
import qualified Data.Set as Set
import Network.HTTP.Types.Status
@ -138,75 +137,92 @@ instance (Typeable m, Typeable k, Typeable status, Typeable fr, Typeable ct, Typ
instance HasRoute sub => HasRoute (HttpVersion :> sub) where
parseServantRoute args = parseServantRoute @sub @(ServantApiDirect sub) args <&> \case
ServantApiRoute (_ :: Proxy endpoint) f ps qs -> ServantApiRoute (Proxy @(HttpVersion :> endpoint)) f ps qs
ServantApiBaseRoute -> ServantApiBaseRoute
instance HasRoute sub => HasRoute (Vault :> sub) where
parseServantRoute args = parseServantRoute @sub @(ServantApiDirect sub) args <&> \case
ServantApiRoute (_ :: Proxy endpoint) f ps qs -> ServantApiRoute (Proxy @(Vault :> endpoint)) f ps qs
ServantApiBaseRoute -> ServantApiBaseRoute
instance (HasRoute sub, KnownSymbol realm, Typeable a) => HasRoute (BasicAuth realm a :> sub) where
parseServantRoute args = parseServantRoute @sub @(ServantApiDirect sub) args <&> \case
ServantApiRoute (_ :: Proxy endpoint) f ps qs -> ServantApiRoute (Proxy @(BasicAuth realm a :> endpoint)) f ps qs
ServantApiBaseRoute -> ServantApiBaseRoute
instance (HasRoute sub, KnownSymbol s) => HasRoute (Description s :> sub) where
parseServantRoute args = parseServantRoute @sub @(ServantApiDirect sub) args <&> \case
ServantApiRoute (_ :: Proxy endpoint) f ps qs -> ServantApiRoute (Proxy @(Description s :> endpoint)) f ps qs
ServantApiBaseRoute -> ServantApiBaseRoute
instance (HasRoute sub, KnownSymbol s) => HasRoute (Summary s :> sub) where
parseServantRoute args = parseServantRoute @sub @(ServantApiDirect sub) args <&> \case
ServantApiRoute (_ :: Proxy endpoint) f ps qs -> ServantApiRoute (Proxy @(Summary s :> endpoint)) f ps qs
ServantApiBaseRoute -> ServantApiBaseRoute
instance (HasRoute sub, Typeable tag, Typeable k) => HasRoute (AuthProtect (tag :: k) :> sub) where
parseServantRoute args = parseServantRoute @sub @(ServantApiDirect sub) args <&> \case
ServantApiRoute (_ :: Proxy endpoint) f ps qs -> ServantApiRoute (Proxy @(AuthProtect tag :> endpoint)) f ps qs
ServantApiBaseRoute -> ServantApiBaseRoute
instance HasRoute sub => HasRoute (IsSecure :> sub) where
parseServantRoute args = parseServantRoute @sub @(ServantApiDirect sub) args <&> \case
ServantApiRoute (_ :: Proxy endpoint) f ps qs -> ServantApiRoute (Proxy @(IsSecure :> endpoint)) f ps qs
ServantApiBaseRoute -> ServantApiBaseRoute
instance HasRoute sub => HasRoute (RemoteHost :> sub) where
parseServantRoute args = parseServantRoute @sub @(ServantApiDirect sub) args <&> \case
ServantApiRoute (_ :: Proxy endpoint) f ps qs -> ServantApiRoute (Proxy @(RemoteHost :> endpoint)) f ps qs
ServantApiBaseRoute -> ServantApiBaseRoute
instance (HasRoute sub, Typeable mods, Typeable restr) => HasRoute (CaptureBearerRestriction' mods restr :> sub) where
parseServantRoute args = parseServantRoute @sub @(ServantApiDirect sub) args <&> \case
ServantApiRoute (_ :: Proxy endpoint) f ps qs -> ServantApiRoute (Proxy @(CaptureBearerRestriction' mods restr :> endpoint)) f ps qs
ServantApiBaseRoute -> ServantApiBaseRoute
instance (HasRoute sub, Typeable mods) => HasRoute (CaptureBearerToken' mods :> sub) where
parseServantRoute args = parseServantRoute @sub @(ServantApiDirect sub) args <&> \case
ServantApiRoute (_ :: Proxy endpoint) f ps qs -> ServantApiRoute (Proxy @(CaptureBearerToken' mods :> endpoint)) f ps qs
ServantApiBaseRoute -> ServantApiBaseRoute
instance (KnownSymbol sym, HasRoute sub, HasLink sub) => HasRoute (sym :> sub) where
parseServantRoute (p : ps, qs)
| p == escapedSymbol (Proxy @sym)
= parseServantRoute @sub @(ServantApiDirect sub) (ps, qs) <&> \case
ServantApiRoute (_ :: Proxy endpoint) f ps' qs' -> ServantApiRoute (Proxy @(sym :> endpoint)) f (escapedSymbol (Proxy @sym) : ps') qs'
ServantApiBaseRoute -> ServantApiBaseRoute
parseServantRoute _ = Nothing
instance (HasRoute a, HasRoute b) => HasRoute (a :<|> b) where
parseServantRoute args = asum
[ parseServantRoute @a @(ServantApiDirect a) args <&> \case
ServantApiRoute (_ :: Proxy endpoint) f ps qs -> ServantApiRoute (Proxy @endpoint) f ps qs
ServantApiBaseRoute -> ServantApiBaseRoute
, parseServantRoute @b @(ServantApiDirect b) args <&> \case
ServantApiRoute (_ :: Proxy endpoint) f ps qs -> ServantApiRoute (Proxy @endpoint) f ps qs
ServantApiBaseRoute -> ServantApiBaseRoute
]
instance (HasRoute sub, Typeable mods, Typeable ct, Typeable a) => HasRoute (ReqBody' mods ct a :> sub) where
parseServantRoute args = parseServantRoute @sub @(ServantApiDirect sub) args <&> \case
ServantApiRoute (_ :: Proxy endpoint) f ps qs -> ServantApiRoute (Proxy @(ReqBody' mods ct a :> endpoint)) f ps qs
ServantApiBaseRoute -> ServantApiBaseRoute
instance (HasRoute sub, Typeable mods, Typeable framing, Typeable ct, Typeable a) => HasRoute (StreamBody' mods framing ct a :> sub) where
parseServantRoute args = parseServantRoute @sub @(ServantApiDirect sub) args <&> \case
ServantApiRoute (_ :: Proxy endpoint) f ps qs -> ServantApiRoute (Proxy @(StreamBody' mods framing ct a :> endpoint)) f ps qs
ServantApiBaseRoute -> ServantApiBaseRoute
instance (HasRoute sub, KnownSymbol sym, Typeable mods, Typeable a) => HasRoute (Header' mods sym (a :: Type) :> sub) where
parseServantRoute args = parseServantRoute @sub @(ServantApiDirect sub) args <&> \case
ServantApiRoute (_ :: Proxy endpoint) f ps qs -> ServantApiRoute (Proxy @(Header' mods sym a :> endpoint)) f ps qs
ServantApiBaseRoute -> ServantApiBaseRoute
instance (HasRoute sub, Typeable mods, KnownSymbol sym, Typeable v, ToHttpApiDataInjective v, FromHttpApiData v) => HasRoute (Capture' mods sym (v :: Type) :> sub) where
parseServantRoute (p : ps, qs)
| Right v <- parseUrlPiece @v p
= parseServantRoute @sub @(ServantApiDirect sub) (ps, qs) <&> \case
ServantApiRoute (_ :: Proxy endpoint) f ps' qs' -> ServantApiRoute (Proxy @(Capture' mods sym v :> endpoint)) (f . ($ v)) (toUrlPieceInjective v : ps') qs'
ServantApiBaseRoute -> ServantApiBaseRoute
parseServantRoute _ = Nothing
instance (HasRoute sub, Typeable mods, KnownSymbol sym, Typeable plaintext, ToHttpApiDataInjective ciphertext, FromHttpApiData ciphertext, Typeable ciphertext) => HasRoute (CaptureCryptoID' mods ciphertext sym plaintext :> sub) where
@ -214,11 +230,13 @@ instance (HasRoute sub, Typeable mods, KnownSymbol sym, Typeable plaintext, ToHt
| Right v <- parseUrlPiece @(CryptoID ciphertext plaintext) p
= parseServantRoute @sub @(ServantApiDirect sub) (ps, qs) <&> \case
ServantApiRoute (_ :: Proxy endpoint) f ps' qs' -> ServantApiRoute (Proxy @(CaptureCryptoID' mods ciphertext sym plaintext :> endpoint)) (f . ($ v)) (toUrlPieceInjective v : ps') qs'
ServantApiBaseRoute -> ServantApiBaseRoute
parseServantRoute _ = Nothing
instance (HasRoute sub, KnownNat major, KnownNat minor, KnownNat patch) => HasRoute (ApiVersion major minor patch :> sub) where
parseServantRoute args = parseServantRoute @sub @(ServantApiDirect sub) args <&> \case
ServantApiRoute (_ :: Proxy endpoint) f ps qs -> ServantApiRoute (Proxy @(ApiVersion major minor patch :> endpoint)) f ps qs
ServantApiBaseRoute -> ServantApiBaseRoute
data ServantApi (proxy :: k) = ServantApi
@ -249,45 +267,56 @@ instance HasRoute (ServantApiUnproxy' proxy) => RenderRoute (ServantApi proxy) w
(Proxy endpoint)
(forall a. MkLink endpoint a -> a)
[Text] (HashMap Text [Text])
| ServantApiBaseRoute
renderRoute (ServantApiRoute (_ :: Proxy endpoint) f _ _) = f $ safeLink' renderServantRoute (Proxy @(ServantApiUnproxy' proxy)) (Proxy @endpoint)
renderRoute ServantApiBaseRoute = mempty
instance HasRoute (ServantApiUnproxy' proxy) => Eq (Route (ServantApi proxy)) where
(ServantApiRoute (_ :: Proxy endpoint) _ ps qs) == (ServantApiRoute (_ :: Proxy endpoint') _ ps' qs')
= case eqT @endpoint @endpoint' of
Just Refl -> ps == ps' && qs == qs'
Nothing -> False
ServantApiBaseRoute == ServantApiBaseRoute = True
_ == _ = False
instance HasRoute (ServantApiUnproxy' proxy) => Ord (Route (ServantApi proxy)) where
compare (ServantApiRoute (_ :: Proxy endpoint) _ ps qs) (ServantApiRoute (_ :: Proxy endpoint') _ ps' qs')
= case eqT @endpoint @endpoint' of
Just Refl -> compare ps ps' <> compare qs qs'
Nothing -> typeRep (Proxy @endpoint) `compare` typeRep (Proxy @endpoint')
compare ServantApiBaseRoute ServantApiBaseRoute = EQ
compare ServantApiBaseRoute _ = LT
compare _ ServantApiBaseRoute = GT
instance HasRoute (ServantApiUnproxy' proxy) => Hashable (Route (ServantApi proxy)) where
hashWithSalt salt (ServantApiRoute (_ :: Proxy endpoint) _ ps qs) = salt `hashWithSalt` typeRep (Proxy @endpoint) `hashWithSalt` ps `hashWithSalt` qs
hashWithSalt salt (ServantApiRoute (_ :: Proxy endpoint) _ ps qs) = salt `hashWithSalt` (0 :: Int) `hashWithSalt` typeRep (Proxy @endpoint) `hashWithSalt` ps `hashWithSalt` qs
hashWithSalt salt ServantApiBaseRoute = salt `hashWithSalt` (1 :: Int)
instance HasRoute (ServantApiUnproxy' proxy) => Read (Route (ServantApi proxy)) where
readPrec = readP_to_Prec $ \d -> do
when (d > 10) . void $ R.char '('
R.skipSpaces
void $ R.string "ServantApiRoute "
R.skipSpaces
void $ R.string "_ "
R.skipSpaces
asum [ do
void $ R.char '('
R.skipMany . R.manyTill (R.satisfy $ const True) $ R.char ')'
void $ R.char ' '
, R.skipMany . R.manyTill (R.satisfy $ not . Char.isSpace) $ R.satisfy Char.isSpace
]
R.skipSpaces
ps <- readPrec_to_P readPrec 11
void $ R.char ' '
R.skipSpaces
qs <- readPrec_to_P readPrec 11 :: R.ReadP (HashMap Text [Text])
R.skipSpaces
when (d > 10) . void $ R.char ')'
maybe (fail "Could not parse servant route") return $ parseServantRoute (ps, ifoldMap (fmap . (,)) qs)
readPrec = readP_to_Prec $ \d -> asum
[ ServantApiBaseRoute <$ R.string "ServantApiBaseRoute"
, do
when (d > 10) . void $ R.char '('
R.skipSpaces
void $ R.string "ServantApiRoute "
R.skipSpaces
void $ R.string "_ "
R.skipSpaces
asum [ do
void $ R.char '('
R.skipMany . R.manyTill (R.satisfy $ const True) $ R.char ')'
void $ R.char ' '
, R.skipMany . R.manyTill (R.satisfy $ not . Char.isSpace) $ R.satisfy Char.isSpace
]
R.skipSpaces
ps <- readPrec_to_P readPrec 11
void $ R.char ' '
R.skipSpaces
qs <- readPrec_to_P readPrec 11 :: R.ReadP (HashMap Text [Text])
R.skipSpaces
when (d > 10) . void $ R.char ')'
maybe (fail "Could not parse servant route") return $ parseServantRoute (ps, ifoldMap (fmap . (,)) qs)
]
instance HasRoute (ServantApiUnproxy' proxy) => Show (Route (ServantApi proxy)) where
showsPrec d (ServantApiRoute (_ :: Proxy endpoint) _ ps qs) = showParen (d > 10)
$ showString "ServantApiRoute "
@ -296,6 +325,7 @@ instance HasRoute (ServantApiUnproxy' proxy) => Show (Route (ServantApi proxy))
. showsPrec 11 ps
. showString " "
. showsPrec 11 qs
showsPrec _ ServantApiBaseRoute = showString "ServantApiBaseRoute"
instance HasRoute (ServantApiUnproxy' proxy) => ParseRoute (ServantApi proxy) where
parseRoute = parseServantRoute
@ -385,10 +415,10 @@ servantApiLink _ _ = safeLink' (fromMaybe (error "Could not parse result of safe
guardEndpoint _ = Nothing
data ServantHandlerContextFor site = ServantHandlerContextFor
{ sctxSite :: site
, sctxRequest :: Request
}
class HasServantHandlerContext site where
data ServantHandlerContextFor site :: Type
getSCtxSite :: ServantHandlerContextFor site -> site
getSCtxRequest :: ServantHandlerContextFor site -> Request
newtype ServantHandlerFor site a = ServantHandlerFor { unServantHandlerFor :: ServantHandlerContextFor site -> Handler a }
deriving (Generic, Typeable)
@ -404,10 +434,10 @@ getServantContext = liftServantHandler $ ServantHandlerFor return
getsServantContext :: (site ~ site', MonadServantHandler site m) => (ServantHandlerContextFor site' -> a) -> m a
getsServantContext = liftServantHandler . ServantHandlerFor . (return .)
getYesodApproot :: (Yesod site, MonadServantHandler site m) => m Text
getYesodApproot = getsServantContext $ \ServantHandlerContextFor{..} -> Yesod.getApprootText Yesod.approot sctxSite sctxRequest
getYesodApproot :: (Yesod site, MonadSite site m, MonadRequest m) => m Text
getYesodApproot = Yesod.getApprootText Yesod.approot <$> getSite <*> getRequest
renderRouteAbsolute :: (Yesod site, MonadServantHandler site m) => Route site -> m URI
renderRouteAbsolute :: (Yesod site, MonadSite site m, MonadRequest m) => Route site -> m URI
renderRouteAbsolute (renderRoute -> (ps, qs)) = addRoute . unpack <$> getYesodApproot
where addRoute root = case parseURI root of
Just root' -> root' & uriPathLens . packed %~ addPath
@ -419,13 +449,16 @@ renderRouteAbsolute (renderRoute -> (ps, qs)) = addRoute . unpack <$> getYesodAp
addQuery "?" = addQuery ""
addQuery q = q <> "&" <> tailEx (addQuery "")
class MonadIO m => MonadServantHandler site m | m -> site where
servantApiBaseUrl :: (Yesod site, MonadSite site m, MonadRequest m, MonadThrow m) => (Route (ServantApi proxy) -> Route site) -> m BaseUrl
servantApiBaseUrl = parseBaseUrl . ($ mempty). uriToString (const "") <=< renderRouteAbsolute . ($ ServantApiBaseRoute)
class (MonadIO m, HasServantHandlerContext site) => MonadServantHandler site m | m -> site where
liftServantHandler :: forall a. ServantHandlerFor site a -> m a
instance MonadServantHandler site (ServantHandlerFor site) where
instance HasServantHandlerContext site => MonadServantHandler site (ServantHandlerFor site) where
liftServantHandler = id
instance (MonadTrans t, MonadIO (t (ServantHandlerFor site))) => MonadServantHandler site (t (ServantHandlerFor site)) where
instance (MonadTrans t, MonadIO (t (ServantHandlerFor site)), HasServantHandlerContext site) => MonadServantHandler site (t (ServantHandlerFor site)) where
liftServantHandler = lift
class MonadIO m => MonadHandler m where
@ -443,8 +476,8 @@ class Monad m => MonadSite site m | m -> site where
getsSite :: (site -> a) -> m a
getsSite f = f <$> getSite
instance MonadSite site (ServantHandlerFor site) where
getSite = liftServantHandler . ServantHandlerFor $ return . sctxSite
instance HasServantHandlerContext site => MonadSite site (ServantHandlerFor site) where
getSite = liftServantHandler . ServantHandlerFor $ return . getSCtxSite
instance MonadSite site (Reader site) where
getSite = ask
@ -454,10 +487,22 @@ instance {-# OVERLAPPABLE #-} (Yesod.MonadHandler m, site ~ Yesod.HandlerSite m)
getSite = Yesod.getYesod
getsSite = Yesod.getsYesod
instance {-# OVERLAPPING #-} (MonadTrans t, Monad (t (ServantHandlerFor site))) => MonadSite site (t (ServantHandlerFor site)) where
instance {-# OVERLAPPING #-} (MonadTrans t, Monad (t (ServantHandlerFor site)), HasServantHandlerContext site) => MonadSite site (t (ServantHandlerFor site)) where
getSite = lift getSite
getsSite = lift . getsSite
class Monad m => MonadRequest m where
getRequest :: m Request
instance HasServantHandlerContext site => MonadRequest (ServantHandlerFor site) where
getRequest = liftServantHandler . ServantHandlerFor $ return . getSCtxRequest
instance {-# OVERLAPPABLE #-} (Yesod.MonadHandler m, Monad m) => MonadRequest m where
getRequest = Yesod.waiRequest
instance {-# OVERLAPPING #-} (MonadTrans t, Monad (t (ServantHandlerFor site)), HasServantHandlerContext site) => MonadRequest (t (ServantHandlerFor site)) where
getRequest = lift getRequest
type ServantDBFor site = ReaderT (Yesod.YesodPersistBackend site) (ServantHandlerFor site)
@ -466,6 +511,7 @@ class Yesod.YesodPersist site => ServantPersist site where
defaultRunDB :: ( PersistConfig c
, ServantDBFor site a ~ PersistConfigBackend c (ServantHandlerFor site) a
, HasServantHandlerContext site
)
=> Getting c site c
-> Getting (PersistConfigPool c) site (PersistConfigPool c)
@ -485,12 +531,12 @@ instance Yesod site => ServantLog (ServantLogYesod site) where
logger <- Yesod.makeLogger app
Yesod.messageLoggerSource app logger a b c d
instance ServantLog site => MonadLogger (ServantHandlerFor site) where
instance (ServantLog site, HasServantHandlerContext site) => MonadLogger (ServantHandlerFor site) where
monadLoggerLog a b c d = do
app <- getSite
servantLogLog app a b c d
instance ServantLog site => MonadLoggerIO (ServantHandlerFor site) where
instance (ServantLog site, HasServantHandlerContext site) => MonadLoggerIO (ServantHandlerFor site) where
askLoggerIO = servantLogLog <$> getSite
@ -501,56 +547,6 @@ instance PathPiece a => FromHttpApiData (PathPieceHttpApiData a) where
parseUrlPiece = maybe (Left "Could not convert from HttpApiData via PathPiece") Right . fromPathPiece
instance PathPiece a => ToHttpApiData (PathPieceHttpApiData a) where
toUrlPiece = toPathPiece
data BearerAuth
data SessionAuth
instance HasSwagger sub => HasSwagger (BearerAuth :> sub) where
toSwagger _ = toSwagger (Proxy @sub)
& securityDefinitions <>~ SecurityDefinitions (fromList [(defnKey, defn)])
& allOperations . security <>~ [SecurityRequirement $ fromList [(defnKey, [])]]
where defnKey :: Text
defnKey = "bearer"
defn = SecurityScheme
{ _securitySchemeType
= SecuritySchemeApiKey ApiKeyParams
{ _apiKeyName = "Authorization"
, _apiKeyIn = ApiKeyHeader
}
, _securitySchemeDescription = Just
"JSON Web Token-based API key"
}
instance HasSwagger sub => HasSwagger (SessionAuth :> sub) where
toSwagger _ = toSwagger (Proxy @sub)
& allOperations . security <>~ [SecurityRequirement mempty]
-- We do not expect API clients to be able/willing to conform with
-- our CSRF mitigation, so we mark routes that require it as
-- having unfullfillable security requirements
instance HasLink sub => HasLink (BearerAuth :> sub) where
type MkLink (BearerAuth :> sub) a = MkLink sub a
toLink toA _ = toLink toA (Proxy @sub)
instance HasLink sub => HasLink (SessionAuth :> sub) where
type MkLink (SessionAuth :> sub) a = MkLink sub a
toLink toA _ = toLink toA (Proxy @sub)
instance HasDocs sub => HasDocs (BearerAuth :> sub) where
docsFor _ (endpoint, action) = docsFor (Proxy @sub) (endpoint, action')
where action' = action & authInfo %~ (|> authInfo')
authInfo' = DocAuthentication
""
"A JSON Web Token-based API key"
instance HasDocs sub => HasDocs (SessionAuth :> sub) where
docsFor _ (endpoint, action) = docsFor (Proxy @sub) (endpoint, action')
where action' = action & authInfo %~ (|> authInfo')
authInfo' = DocAuthentication
"When a web session is used for authorization, CSRF-mitigation measures must be observed."
"An active web session identifying the user as one with sufficient authorization"
mkYesodApi :: Name -> [ResourceTree String] -> DecsQ

View File

@ -51,6 +51,8 @@ import qualified Data.SemVer as SemVer
import qualified Data.SemVer.Constraint as SemVer (Constraint)
import qualified Data.SemVer.Constraint as SemVer.Constraint
import qualified Data.HashSet as HashSet
instance Arbitrary Season where
@ -343,7 +345,9 @@ instance Arbitrary RoomReference' where
arbitrary = genericArbitrary
instance Arbitrary ExternalApiConfig where
arbitrary = genericArbitrary
arbitrary = oneof
[ EApiGradelistFormat <$> ((fmap HashSet.fromList . scale (`div` 10) $ listOf1 (resize 3 arbitrary)) `suchThatMap` fromNullable)
]
shrink = genericShrink
instance Arbitrary SemVer.Version where

View File

@ -7,7 +7,14 @@ import Network.URI
import Network.URI.Arbitrary ()
import Servant.Client.Core.BaseUrl
import Control.Lens.Extras
instance Arbitrary BaseUrl where
arbitrary = toBaseUrl <$> arbitrary
where toBaseUrl = either (error . displayException) id . parseBaseUrl . ($ mempty) . uriToString id
arbitrary = (`suchThatMap` toBaseUrl) $ do
uri <- scale (min 10) arbitrary `suchThat` (is _Just . uriAuthority)
uriScheme <- oneof $ map (return . (<> ":")) [ "http", "https" ]
let uriAuthority'' = uriAuthority uri <&> \uriAuthority' -> uriAuthority'{ uriUserInfo = "" }
return (uri, uriScheme, uriAuthority'')
where
toBaseUrl (uri, uriScheme, uriAuthority'') = either (const Nothing) Just . parseBaseUrl . ($ mempty) $ uriToString (const mempty) uri{ uriScheme, uriAuthority = uriAuthority'', uriQuery = "", uriFragment = "" }

View File

@ -8,7 +8,10 @@ import ServantApi.ExternalApis.Type
instance Arbitrary ExternalApiCreationRequest where
arbitrary = genericArbitrary
arbitrary = ExternalApiCreationRequest
<$> scale (`div` 2) arbitrary
<*> scale (`div` 2) arbitrary
<*> scale (`div` 2) arbitrary
shrink = genericShrink

View File

@ -0,0 +1,48 @@
{-# OPTIONS_GHC -Wno-error=unused-local-binds #-}
module ServantApi.ExternalApisSpec where
import TestImport
import ServantApi.ExternalApis.Type
import ServantApi.ExternalApis.TypeSpec ()
import Servant.Client.Core (RequestF(..))
import Servant.Client.Generic
import Utils.Tokens
import Data.Time.Clock (nominalDay)
import qualified Data.HashSet as HashSet
import qualified Data.HashMap.Strict as HashMap
import qualified Data.Sequence as Seq
import Control.Monad.Reader.Class (MonadReader(local))
import Utils (CustomHeader(..), waiCustomHeader)
spec :: Spec
spec = withApp . describe "ExternalApis" $ do
it "Supports dryRun" $ do
adminId <- runDB $ do
Entity adminId _ <- insertEntity $ fakeUser id
ifi <- insert $ School "Institut für Informatik" "IfI" (Just $ 14 * nominalDay) (Just $ 10 * nominalDay) True (ExamModeDNF predDNFFalse) (ExamCloseOnFinished True) SchoolAuthorshipStatementModeOptional Nothing True SchoolAuthorshipStatementModeRequired Nothing False
insert_ $ UserFunction adminId ifi SchoolAdmin
return adminId
accessToken <- runHandler $ encodeBearer =<< bearerToken (HashSet.singleton $ Right adminId) Nothing HashMap.empty Nothing Nothing Nothing
let
insertExternalApi = void $ externalApisCreateR accessToken =<< liftIO (generate $ resize 10 arbitrary)
where ExternalApis{..} = genericClient
withDryRun :: ServantExampleEnv -> ServantExampleEnv
withDryRun seEnv = seEnv
{ yseMakeClientRequest = \burl req -> yseMakeClientRequest seEnv burl req{ requestHeaders = requestHeaders req Seq.:|> waiCustomHeader HeaderDryRun True }
}
externalApiCount = runDB $ count @_ @_ @ExternalApi []
runServantExample ExternalApisR insertExternalApi
liftIO . (`shouldBe` 1) =<< externalApiCount
runServantExample ExternalApisR $ local withDryRun insertExternalApi
liftIO . (`shouldBe` 1) =<< externalApiCount

36
test/ServantApiSpec.hs Normal file
View File

@ -0,0 +1,36 @@
{-# LANGUAGE UndecidableInstances #-}
{-# OPTIONS_GHC -fno-warn-redundant-constraints #-}
module ServantApiSpec where
import TestImport
import ServantApi
import Servant.API
import Servant.API.TypeLevel (MapSub, AppendList)
import Foundation.Servant.Types (ApiVersion)
import GHC.TypeLits
import Data.Kind (Constraint)
type family Unversioned api where
Unversioned (ApiVersion _ _ _ :> _) = '[]
Unversioned (sup :> sub) = MapSub sup (Unversioned sub)
Unversioned (a :<|> b) = AppendList (Unversioned a) (Unversioned b)
Unversioned (Verb method statusCode contentTypes a) = '[Verb method statusCode contentTypes a]
Unversioned (NoContentVerb method) = '[NoContentVerb method]
type family UnversionedError xs :: ErrorMessage where
UnversionedError (x ': '[]) = 'Text "Unversioned API endpoint: " ':$$: ('Text " " ':<>: 'ShowType x)
UnversionedError (x ': xs) = UnversionedError (x ': '[]) ':$$: UnversionedError xs
type family IsEmpty xs :: Constraint where
IsEmpty '[] = ()
IsEmpty xs = TypeError ('Text "All API endpoints must be versioned." ':$$: UnversionedError xs)
spec :: Spec
spec = describe "Servant endpoints" $ it "are all versioned" versioned
where
versioned :: IsEmpty (Unversioned UniWorXApi) => Bool
versioned = True

View File

@ -1,3 +1,5 @@
{-# OPTIONS_GHC -fno-warn-deprecations #-}
module TestImport
( module TestImport
, module X
@ -44,6 +46,34 @@ import Jobs (handleJobs)
import Numeric.Natural as X
import Network.URI.Arbitrary as X ()
import qualified Network.Wai as Wai
import qualified Network.Wai.Test as Wai
import qualified Network.Wai.Test.Internal as Wai (ClientState)
import Network.HTTP.Types (Status(..), hContentType, hAccept)
import Network.HTTP.Types.Header (hHost)
import qualified Network.HTTP.Types as Wai
import Control.Monad.Trans.Except (ExceptT)
import qualified Servant.Client.Core as Servant
import Servant.Client.Core.ClientError
import Servant.Client.Core.RunClient
import Control.Monad.Except (MonadError(..))
import Control.Monad.State.Class (MonadState(..))
import qualified Control.Monad.State.Class as State
import qualified Servant.Types.SourceT as S
import Servant.API (SourceIO)
import Utils (throwExceptT)
import Yesod.Servant (ServantApi, servantApiBaseUrl)
import qualified Data.ByteString as BS
import qualified Data.ByteString.Lazy as Lazy (ByteString)
import qualified Data.ByteString.Lazy as LBS hiding (ByteString)
import qualified Data.Binary.Builder as B
import Network.HTTP.Media (renderHeader)
import Control.Monad.Fail
import Control.Lens as X hiding ((<.), elements)
import Network.IP.Addr as X (IP)
@ -133,3 +163,105 @@ lawsCheckHspec p = parallel . describe (show $ typeRep p) . mapM_ (checkHspec .
where
checkHspec (Laws className properties) = describe className $
forM_ properties $ \(name, prop) -> it name $ property prop
newtype ServantExample a = ServantExample
{ unServantExample :: ReaderT ServantExampleEnv (ExceptT ClientError Wai.Session) a
} deriving stock (Generic, Typeable)
deriving newtype (Functor, Applicative, Monad, MonadIO, MonadReader ServantExampleEnv, MonadError ClientError, MonadThrow, MonadCatch, MonadState Wai.ClientState)
data ServantExampleEnv = ServantExampleEnv
{ yseBaseUrl :: BaseUrl
, yseMakeClientRequest :: BaseUrl -> Servant.Request -> IO Wai.Request
} deriving (Generic, Typeable)
runServantExample :: (Route (ServantApi proxy) -> Route UniWorX) -> ServantExample a -> YesodExample UniWorX a
runServantExample apiR (ServantExample act) = do
yseBaseUrl <- runHandler $ servantApiBaseUrl apiR
let yseMakeClientRequest burl Servant.Request{..} = do
((body, bodyLength), contentTypeHdr) <- case requestBody of
Nothing -> return ((return BS.empty, Wai.KnownLength 0), Nothing)
Just (body', typ) -> let (mkBody, bLength) = convertBody body'
in (, Just (hContentType, renderHeader typ)) . (, bLength) <$> mkBody
return $ Wai.defaultRequest
{ Wai.requestMethod = requestMethod
, Wai.requestHeaders = maybeToList acceptHdr ++ maybeToList contentTypeHdr ++ headers
, Wai.requestHeaderHost =
let BaseUrl{..} = yseBaseUrl
in Just . encodeUtf8 . pack $ baseUrlHost <> bool (":" <> show baseUrlPort) mempty (baseUrlPort == 80)
, Wai.requestBody = body, Wai.requestBodyLength = bodyLength
, Wai.isSecure = isSecure
}
& flip Wai.setPath (encodeUtf8 (pack $ baseUrlPath burl) <> toStrict (B.toLazyByteString requestPath) <> Wai.renderQuery True (toList requestQueryString))
where
headers = filter (\(h, _) -> h `notElem` [hAccept, hContentType, hHost]) $ toList requestHeaders
acceptHdr
| null hs = Nothing
| otherwise = Just (hAccept, renderHeader hs)
where
hs = toList requestAccept
convertBody :: Servant.RequestBody -> (IO (IO ByteString), Wai.RequestBodyLength)
convertBody bd = case bd of
Servant.RequestBodyLBS body' -> ( givesPopper . S.source . map fromStrict $ LBS.toChunks body'
, Wai.KnownLength . fromIntegral $ LBS.length body'
)
Servant.RequestBodyBS body' -> ( return $ return body'
, Wai.KnownLength . fromIntegral $ BS.length body'
)
Servant.RequestBodySource sourceIO -> ( givesPopper sourceIO
, Wai.ChunkedBody
)
where
givesPopper :: SourceIO Lazy.ByteString -> IO (IO ByteString)
givesPopper sourceIO = S.unSourceT sourceIO $ \step0 -> do
ref <- newMVar step0
return $ modifyMVar ref nextBs
nextBs S.Stop = return (S.Stop, BS.empty)
nextBs (S.Error err) = fail err
nextBs (S.Skip s) = nextBs s
nextBs (S.Effect ms) = ms >>= nextBs
nextBs (S.Yield lbs s) = case LBS.toChunks lbs of
[] -> nextBs s
(x:xs) | BS.null x -> nextBs step'
| otherwise -> return (step', x)
where
step' = S.Yield (LBS.fromChunks xs) s
isSecure = case baseUrlScheme burl of
Servant.Http -> False
Servant.Https -> True
YesodExampleData waiApp _ _ _ <- State.get
liftIO . flip Wai.runSession waiApp . throwExceptT $ runReaderT act ServantExampleEnv{..}
instance RunClient ServantExample where
runRequestAcceptStatus acceptStatus req = do
ServantExampleEnv{..} <- ask
waiRequest <- liftIO $ yseMakeClientRequest yseBaseUrl req
waiResponse@Wai.SResponse{..} <- ServantExample . lift . lift $ Wai.request waiRequest
let Status{..} = simpleStatus
statusOk = case acceptStatus of
Nothing -> 200 <= statusCode && statusCode < 300
Just good -> simpleStatus `elem` good
response = (waiResponseToResponse waiResponse) { Servant.responseHttpVersion = Wai.httpVersion waiRequest }
unless statusOk $
throwError $ mkFailureResponse yseBaseUrl req response
return response
where
mkFailureResponse :: BaseUrl -> Servant.Request -> Servant.ResponseF Lazy.ByteString -> ClientError
mkFailureResponse burl request' =
FailureResponse (bimap (const ()) f request')
where
f b = (burl, LBS.toStrict $ B.toLazyByteString b)
waiResponseToResponse :: Wai.SResponse -> Servant.Response
waiResponseToResponse Wai.SResponse{..} = Servant.Response
{ responseStatusCode = simpleStatus
, responseBody = simpleBody
, responseHeaders = fromList simpleHeaders
, responseHttpVersion = error "WAI Response does not carry http version information"
}
throwClientError = throwError