feat(servant): dry-run support
This commit is contained in:
parent
605b7758e6
commit
47df8a312f
@ -346,6 +346,7 @@ tests:
|
||||
- quickcheck-io
|
||||
- network-arbitrary
|
||||
- lens-properties
|
||||
- http-media
|
||||
ghc-options:
|
||||
- -fno-warn-orphans
|
||||
- -threaded -rtsopts "-with-rtsopts=-N -T"
|
||||
|
||||
@ -9,6 +9,7 @@ module Foundation.Authorization
|
||||
, wouldHaveReadAccessToIff, wouldHaveWriteAccessToIff
|
||||
, AuthContext(..), getAuthContext
|
||||
, isDryRun, isDryRunDB
|
||||
, IsDryRun(..)
|
||||
, maybeBearerToken, requireBearerToken
|
||||
, requireCurrentBearerRestrictions, maybeCurrentBearerRestrictions
|
||||
, BearerAuthSite, MonadAP
|
||||
@ -276,7 +277,9 @@ getAuthContext = liftHandler $ do
|
||||
return authCtx
|
||||
|
||||
newtype IsDryRun = MkIsDryRun { unIsDryRun :: Bool }
|
||||
deriving (Eq, Ord, Read, Show, Generic, Typeable)
|
||||
deriving stock (Read, Show, Generic, Typeable)
|
||||
deriving newtype (Eq, Ord)
|
||||
deriving (Semigroup, Monoid) via Any
|
||||
|
||||
isDryRun :: ( HasCallStack
|
||||
, BearerAuthSite UniWorX
|
||||
@ -296,6 +299,7 @@ isDryRunDB :: forall m backend m'.
|
||||
isDryRunDB = fmap unIsDryRun . cached . fmap MkIsDryRun $ orM
|
||||
[ hasGlobalPostParam PostDryRun
|
||||
, hasGlobalGetParam GetDryRun
|
||||
, hasCustomHeader HeaderDryRun
|
||||
, and2M bearerDryRun bearerRequired
|
||||
]
|
||||
where
|
||||
|
||||
@ -9,19 +9,15 @@ module Foundation.Servant
|
||||
|
||||
import Import.Servant.NoFoundation
|
||||
import Foundation.DB (runSqlPoolRetry')
|
||||
import Foundation.Authorization (maybeBearerToken)
|
||||
import Foundation.Authorization (maybeBearerToken, IsDryRun(..), isDryRun)
|
||||
import Foundation.Instances ()
|
||||
|
||||
import qualified Data.HashMap.Strict.InsOrd as HashMap
|
||||
|
||||
import Network.Wai (Middleware, modifyResponse, mapResponseHeaders, vault)
|
||||
import Network.Wai (Middleware, modifyResponse, mapResponseHeaders)
|
||||
import qualified Network.Wai as W
|
||||
|
||||
import qualified Data.Vault.Lazy as Vault
|
||||
|
||||
import Servant.Server.Internal.DelayedIO (DelayedIO, delayedFail, delayedFailFatal, withRequest)
|
||||
|
||||
import System.IO.Unsafe (unsafePerformIO)
|
||||
import Servant.Server.Internal.DelayedIO (DelayedIO, delayedFail, delayedFailFatal)
|
||||
|
||||
import qualified Yesod.Servant as Servant
|
||||
|
||||
@ -32,21 +28,16 @@ import Control.Monad.Catch.Pure
|
||||
import Servant.Server.Internal.Delayed
|
||||
import Servant.Server.Internal.Router
|
||||
|
||||
-- import Database.Persist.Sql (transactionUndo)
|
||||
import Database.Persist.Sql (transactionUndo)
|
||||
|
||||
|
||||
waiBearerKey :: Vault.Key (Maybe (BearerToken UniWorX))
|
||||
waiBearerKey = unsafePerformIO Vault.newKey
|
||||
{-# NOINLINE waiBearerKey #-}
|
||||
|
||||
waiRouteKey :: Vault.Key (Route UniWorX)
|
||||
waiRouteKey = unsafePerformIO Vault.newKey
|
||||
{-# NOINLINE waiRouteKey #-}
|
||||
import qualified Data.CaseInsensitive as CI
|
||||
|
||||
|
||||
instance ( HasServer sub context
|
||||
, ToJSON restr, FromJSON restr
|
||||
, SBoolI (FoldRequired mods)
|
||||
, HasContextEntry context (Maybe (BearerToken UniWorX))
|
||||
, HasContextEntry context (Maybe (Route UniWorX))
|
||||
)
|
||||
=> HasServer (CaptureBearerRestriction' mods restr :> sub) context
|
||||
where
|
||||
@ -56,25 +47,25 @@ instance ( HasServer sub context
|
||||
hoistServerWithContext _ pc nt s = hoistServerWithContext (Proxy @sub) pc nt . s
|
||||
|
||||
route _ context subserver
|
||||
= route (Proxy @sub) context (subserver `addAuthCheck` withRequest bearerCheck)
|
||||
= route (Proxy @sub) context (subserver `addAuthCheck` bearerCheck)
|
||||
where
|
||||
bearerCheck :: W.Request -> DelayedIO (RequiredArgument mods restr)
|
||||
bearerCheck req = do
|
||||
let bearer = Vault.lookup waiBearerKey $ vault req
|
||||
cRoute = Vault.lookup waiRouteKey $ vault req
|
||||
bearerCheck :: DelayedIO (RequiredArgument mods restr)
|
||||
bearerCheck = do
|
||||
let bearer :: Maybe (BearerToken UniWorX)
|
||||
bearer = getContextEntry context
|
||||
cRoute :: Maybe (Route UniWorX)
|
||||
cRoute = getContextEntry context
|
||||
|
||||
noRouteStored, noTokenStored, noTokenProvided, noRestrictionProvided :: ServerError
|
||||
noTokenStored = err500 { errBody = "servantYesodMiddleware did not store bearer token in WAI vault." }
|
||||
noRouteStored, noTokenProvided, noRestrictionProvided :: ServerError
|
||||
noTokenProvided = err400 { errBody = "The behaviour of this route depends on restrictions stored in the bearer token used for authorization. Therefor providing a bearer token is required." }
|
||||
noRestrictionProvided = err400 { errBody = "The behaviour of this route depends on restrictions stored in the bearer token used for authorization. Therefor the provided bearer token must contain a restriction entry for this route." }
|
||||
noRouteStored = err500 { errBody = "servantYesodMiddleware did not store current route in WAI vault." }
|
||||
|
||||
exceptT delayedFailFatal return $ do
|
||||
bearer' <- maybeExceptT' noTokenStored bearer
|
||||
cRoute' <- maybeExceptT' noRouteStored cRoute
|
||||
|
||||
let mbRet :: Maybe (Maybe restr)
|
||||
mbRet = bearer' <&> preview (_bearerRestrictionIx cRoute')
|
||||
mbRet = bearer <&> preview (_bearerRestrictionIx cRoute')
|
||||
case sbool @(FoldRequired mods) of
|
||||
SFalse -> return $ join mbRet
|
||||
STrue -> maybe (throwE noTokenProvided) (maybe (throwE noRestrictionProvided) return) mbRet
|
||||
@ -82,6 +73,7 @@ instance ( HasServer sub context
|
||||
|
||||
instance ( HasServer sub context
|
||||
, SBoolI (FoldRequired mods)
|
||||
, HasContextEntry context (Maybe (BearerToken UniWorX))
|
||||
)
|
||||
=> HasServer (CaptureBearerToken' mods :> sub) context
|
||||
where
|
||||
@ -91,21 +83,20 @@ instance ( HasServer sub context
|
||||
hoistServerWithContext _ pc nt s = hoistServerWithContext (Proxy @sub) pc nt . s
|
||||
|
||||
route _ context subserver
|
||||
= route (Proxy @sub) context (subserver `addAuthCheck` withRequest bearerCheck)
|
||||
= route (Proxy @sub) context (subserver `addAuthCheck` bearerCheck)
|
||||
where
|
||||
bearerCheck :: W.Request -> DelayedIO (RequiredArgument mods (BearerToken UniWorX))
|
||||
bearerCheck req = do
|
||||
let bearer = Vault.lookup waiBearerKey $ vault req
|
||||
bearerCheck :: DelayedIO (RequiredArgument mods (BearerToken UniWorX))
|
||||
bearerCheck = do
|
||||
let bearer :: Maybe (BearerToken UniWorX)
|
||||
bearer = getContextEntry context
|
||||
|
||||
noTokenStored, noTokenProvided :: ServerError
|
||||
noTokenStored = err500 { errBody = "servantYesodMiddleware did not store bearer token in WAI vault." }
|
||||
noTokenProvided :: ServerError
|
||||
noTokenProvided = err400 { errBody = "The behaviour of this route depends on restrictions stored in the bearer token used for authorization. Therefor providing a bearer token is required." }
|
||||
|
||||
exceptT delayedFailFatal return $ do
|
||||
bearer' <- maybeExceptT' noTokenStored bearer
|
||||
case sbool @(FoldRequired mods) of
|
||||
SFalse -> return bearer'
|
||||
STrue -> maybe (throwE noTokenProvided) return bearer'
|
||||
SFalse -> return bearer
|
||||
STrue -> maybe (throwE noTokenProvided) return bearer
|
||||
|
||||
|
||||
instance ( HasServer sub context
|
||||
@ -132,23 +123,40 @@ instance ( HasServer sub context
|
||||
decrypt' inp = left tshow . runCatch . runReaderT (decrypt inp) . appCryptoIDKey $ getContextEntry context
|
||||
|
||||
|
||||
type UniWorXContext = UniWorX ': '[]
|
||||
type UniWorXContext = Maybe (Route UniWorX) ': Maybe (BearerToken UniWorX) ': IsDryRun ': UniWorX ': '[]
|
||||
type ServantHandler = ServantHandlerFor UniWorX
|
||||
type ServantDB = ServantDBFor UniWorX
|
||||
|
||||
deriving via (ServantLogYesod UniWorX) instance ServantLog UniWorX
|
||||
|
||||
instance HasServantHandlerContext UniWorX where
|
||||
data ServantHandlerContextFor UniWorX = ServantHandlerContextForUniWorX
|
||||
{ usctxSite :: UniWorX
|
||||
, usctxRequest :: W.Request
|
||||
, usctxIsDryRun :: IsDryRun
|
||||
}
|
||||
getSCtxSite = usctxSite
|
||||
getSCtxRequest = usctxRequest
|
||||
|
||||
class (HasServer (ServantApiUnproxy' proxy) UniWorXContext, Servant.HasRoute (ServantApiUnproxy' proxy)) => ServantApiDispatchUniWorX proxy where
|
||||
servantServer' :: ServantApi proxy -> ServerT (ServantApiUnproxy' proxy) ServantHandler
|
||||
|
||||
instance ServantApiDispatchUniWorX proxy => ServantApiDispatch UniWorXContext ServantHandler UniWorX proxy where
|
||||
servantContext _ app _ = return $ app :. EmptyContext
|
||||
servantHoist _ sctxSite sctxRequest _ = ($ ServantHandlerContextFor{..}) . unServantHandlerFor
|
||||
servantMiddleware _ _ _ = modifyResponse (mapResponseHeaders setDefaultHeaders) . fixTrailingSlash
|
||||
servantYesodMiddleware _ _ = appEndo <$> foldMapM (fmap Endo) [storeBearerToken, storeCurrentRoute]
|
||||
servantContext _ app _ = do
|
||||
isDryRun' <- MkIsDryRun <$> isDryRun
|
||||
restr <- maybeBearerToken
|
||||
cRoute <- getCurrentRoute
|
||||
return $ cRoute :. restr :. isDryRun' :. app :. EmptyContext
|
||||
servantHoist _ usctxSite usctxRequest ctx = ($ ServantHandlerContextForUniWorX{ usctxIsDryRun = getContextEntry ctx, .. }) . unServantHandlerFor
|
||||
servantMiddleware _ _ ctx = appEndo . foldMap Endo $
|
||||
guardOn (unIsDryRun $ getContextEntry ctx) (modifyResponse $ mapResponseHeaders setDryRunHeader)
|
||||
++ [ modifyResponse (mapResponseHeaders setDefaultHeaders)
|
||||
, fixTrailingSlash
|
||||
]
|
||||
servantYesodMiddleware _ _ = return id
|
||||
servantServer proxy _ = servantServer' proxy
|
||||
|
||||
setDefaultHeaders :: ResponseHeaders -> ResponseHeaders
|
||||
setDefaultHeaders, setDryRunHeader :: ResponseHeaders -> ResponseHeaders
|
||||
setDefaultHeaders existing = HashMap.toList $ HashMap.fromList existing <> defaultHeaders
|
||||
where defaultHeaders = HashMap.fromList
|
||||
[ ("X-Frame-Options", "sameorigin")
|
||||
@ -156,6 +164,7 @@ setDefaultHeaders existing = HashMap.toList $ HashMap.fromList existing <> defau
|
||||
, ("Vary", "Accept")
|
||||
, ("X-XSS-Protection", "1; mode=block")
|
||||
]
|
||||
setDryRunHeader existing = HashMap.toList $ HashMap.fromList existing <> HashMap.singleton (CI.mk . encodeUtf8 $ toPathPiece HeaderDryRun) (encodeUtf8 $ toPathPiece True)
|
||||
|
||||
fixTrailingSlash :: Middleware
|
||||
-- ^ `servant-server` contains a special case in their implementation
|
||||
@ -170,17 +179,6 @@ fixTrailingSlash = (. fixTrailingSlash')
|
||||
| otherwise
|
||||
= req
|
||||
|
||||
storeBearerToken, storeCurrentRoute :: HandlerFor UniWorX Middleware
|
||||
storeBearerToken = do
|
||||
restr <- maybeBearerToken
|
||||
return $ \app req -> app req{ vault = Vault.insert waiBearerKey restr $ vault req }
|
||||
storeCurrentRoute = do
|
||||
cRoute <- getCurrentRoute
|
||||
|
||||
$logDebugS "storeCurrentRoute" $ tshow cRoute
|
||||
|
||||
return $ \app req -> app req{ vault = maybe id (Vault.insert waiRouteKey) cRoute $ vault req }
|
||||
|
||||
|
||||
instance ServantPersist UniWorX where
|
||||
runDB :: HasCallStack => ServantDBFor UniWorX a -> ServantHandlerFor UniWorX a
|
||||
@ -189,9 +187,9 @@ instance ServantPersist UniWorX where
|
||||
runDB' :: CallStack -> ServantDBFor UniWorX a -> ServantHandlerFor UniWorX a
|
||||
runDB' lbl action = do
|
||||
$logDebugS "ServantPersist" "runDB"
|
||||
-- let action' = do
|
||||
-- dryRun <- isDryRunDB
|
||||
-- if | dryRun -> action <* transactionUndo
|
||||
-- | otherwise -> action
|
||||
MkIsDryRun dryRun <- getsServantContext usctxIsDryRun
|
||||
let action'
|
||||
| dryRun = action <* transactionUndo
|
||||
| otherwise = action
|
||||
|
||||
flip (runSqlPoolRetry' action) lbl . appConnPool =<< getSite
|
||||
flip (runSqlPoolRetry' action') lbl . appConnPool =<< getSite
|
||||
|
||||
@ -5,12 +5,14 @@ module Foundation.Servant.Types
|
||||
, CaptureBearerToken, CaptureBearerToken'
|
||||
, CaptureCryptoID', CaptureCryptoID, CaptureCryptoUUID, CaptureCryptoFileName
|
||||
, ApiVersion, apiVersionToSemVer, matchesApiVersion
|
||||
, BearerAuth, SessionAuth
|
||||
) where
|
||||
|
||||
import ClassyPrelude
|
||||
import ClassyPrelude hiding (fromList)
|
||||
import Data.Proxy
|
||||
|
||||
import Servant.API
|
||||
import Servant.API.Modifiers (FoldRequired)
|
||||
import Servant.API.Description
|
||||
import Servant.Swagger
|
||||
import Servant.Docs
|
||||
@ -21,6 +23,13 @@ import Servant.Server.Internal.Delayed
|
||||
import Servant.Server.Internal.ErrorFormatter
|
||||
-- import Servant.Server.Internal.DelayedIO
|
||||
|
||||
import Servant.Client.Core.RunClient (RunClient)
|
||||
import Servant.Client.Core.HasClient
|
||||
import qualified Servant.Client.Core.Request as Servant (Request)
|
||||
import qualified Servant.Client.Core.Request as Request
|
||||
|
||||
import Jose.Jwt (Jwt(..))
|
||||
|
||||
import Network.Wai (mapResponseHeaders, requestHeaders)
|
||||
|
||||
import Control.Lens hiding (Context)
|
||||
@ -31,8 +40,9 @@ import Data.CryptoID.Class.ImplicitNamespace
|
||||
import Data.CryptoID.Instances ()
|
||||
|
||||
import GHC.TypeLits
|
||||
import GHC.Exts (IsList(..))
|
||||
|
||||
import Data.Swagger (ToParamSchema)
|
||||
import Data.Swagger hiding (version)
|
||||
|
||||
import Data.Kind (Type)
|
||||
|
||||
@ -114,6 +124,11 @@ instance HasDocs sub => HasDocs (CaptureBearerToken' mods :> sub) where
|
||||
instance (ToCapture (Capture sym ciphertext), KnownSymbol sym, HasDocs sub) => HasDocs (CaptureCryptoID' mods ciphertext sym plaintext :> sub) where
|
||||
docsFor _ = docsFor $ Proxy @(Capture' mods sym ciphertext :> sub)
|
||||
|
||||
instance (RunClient m, HasClient m (Capture' mods sym (CryptoID ciphertext plaintext) :> sub)) => HasClient m (CaptureCryptoID' mods ciphertext sym plaintext :> sub) where
|
||||
type Client m (CaptureCryptoID' mods ciphertext sym plaintext :> sub) = Client m (Capture' mods sym (CryptoID ciphertext plaintext) :> sub)
|
||||
clientWithRoute pm _ = clientWithRoute pm $ Proxy @(Capture' mods sym (CryptoID ciphertext plaintext) :> sub)
|
||||
hoistClientMonad pm _ = hoistClientMonad pm $ Proxy @(Capture' mods sym (CryptoID ciphertext plaintext) :> sub)
|
||||
|
||||
|
||||
type family ApiVersionSub major minor patch sup sub where
|
||||
ApiVersionSub major minor patch (ApiVersion major' minor' patch') sub = ApiVersion major' minor' patch' :> sub
|
||||
@ -143,8 +158,30 @@ instance ( HasServer (ApiVersion major minor patch :> a) context
|
||||
choice' = case (sbool :: SBool (IsLT (CmpVersion (FinalApiVersion (ApiVersion major minor patch :> a)) (FinalApiVersion (ApiVersion major minor patch :> b))))) of
|
||||
STrue -> flip choice
|
||||
SFalse -> choice
|
||||
|
||||
instance (RunClient m, HasClient m (ApiVersionSub major minor patch sup sub)) => HasClient m (ApiVersion major minor patch :> ((sup :: Type) :> sub)) where
|
||||
type Client m (ApiVersion major minor patch :> (sup :> sub)) = Client m (ApiVersionSub major minor patch sup sub)
|
||||
clientWithRoute pm _ = clientWithRoute pm $ Proxy @(ApiVersionSub major minor patch sup sub)
|
||||
hoistClientMonad pm _ = hoistClientMonad pm $ Proxy @(ApiVersionSub major minor patch sup sub)
|
||||
|
||||
instance (RunClient m, HasClient m (sup :> (ApiVersion major minor patch :> sub))) => HasClient m (ApiVersion major minor patch :> ((sup :: Symbol) :> sub)) where
|
||||
type Client m (ApiVersion major minor patch :> (sup :> sub)) = Client m (sup :> (ApiVersion major minor patch :> sub))
|
||||
clientWithRoute pm _ = clientWithRoute pm $ Proxy @(sup :> (ApiVersion major minor patch :> sub))
|
||||
hoistClientMonad pm _ = hoistClientMonad pm $ Proxy @(sup :> (ApiVersion major minor patch :> sub))
|
||||
|
||||
instance ( HasClient m (ApiVersion major minor patch :> a)
|
||||
, HasClient m (ApiVersion major minor patch :> b)
|
||||
) => HasClient m (ApiVersion major minor patch :> (a :<|> b)) where
|
||||
type Client m (ApiVersion major minor patch :> (a :<|> b)) = Client m (ApiVersion major minor patch :> a) :<|> Client m (ApiVersion major minor patch :> b)
|
||||
clientWithRoute pm _ req = clientWithRoute pm (Proxy @(ApiVersion major minor patch :> a)) req
|
||||
:<|> clientWithRoute pm (Proxy @(ApiVersion major minor patch :> b)) req
|
||||
hoistClientMonad pm _ f (ca :<|> cb) = hoistClientMonad pm (Proxy @(ApiVersion major minor patch :> a)) f ca
|
||||
:<|> hoistClientMonad pm (Proxy @(ApiVersion major minor patch :> b)) f cb
|
||||
|
||||
|
||||
versionRequestHeaderName :: CI ByteString
|
||||
versionRequestHeaderName = "Accept-API-Version"
|
||||
|
||||
routeWithApiVersion :: forall api context env major minor patch.
|
||||
( HasServer api context
|
||||
, KnownNat major, KnownNat minor, KnownNat patch
|
||||
@ -168,7 +205,6 @@ routeWithApiVersion _ _ context subserver = RawRouter $ \env req ((. addVersion)
|
||||
version = apiVersionToSemVer $ Proxy @(ApiVersion major minor patch)
|
||||
|
||||
versionHeaderName = "API-Version"
|
||||
versionRequestHeaderName = "Accept-API-Version"
|
||||
versionHeader = encodeUtf8 $ SemVer.toText version
|
||||
|
||||
notFound = notFoundErrorFormatter . getContextEntry $ mkContextWithErrorFormatter context
|
||||
@ -194,6 +230,26 @@ instance ( HasServer (NoContentVerb method) context
|
||||
route _ = routeWithApiVersion (Proxy @(ApiVersion major minor patch)) (Proxy @(NoContentVerb method))
|
||||
|
||||
|
||||
semVerCompatibleTo :: SemVer.Version -> SemVer.Constraint
|
||||
semVerCompatibleTo v = SemVer.Constraint.CAnd (SemVer.Constraint.CGtEq v) (SemVer.Constraint.CLt $ SemVer.incrementMajor v)
|
||||
|
||||
instance ( HasClient m (Verb method statusCode contentTypes a)
|
||||
, KnownNat major, KnownNat minor, KnownNat patch
|
||||
) => HasClient m (ApiVersion major minor patch :> Verb method statusCode contentTypes a) where
|
||||
type Client m (ApiVersion major minor patch :> Verb method statusCode contentTypes a) = Client m (Verb method statusCode contentTypes a)
|
||||
clientWithRoute pm _ = clientWithRoute pm (Proxy @(Verb method statusCode contentTypes a)) . Request.addHeader versionRequestHeaderName (semVerCompatibleTo version)
|
||||
where version = apiVersionToSemVer $ Proxy @(ApiVersion major minor patch)
|
||||
hoistClientMonad pm _ = hoistClientMonad pm $ Proxy @(Verb method statusCode contentTypes a)
|
||||
|
||||
instance ( HasClient m (NoContentVerb method)
|
||||
, KnownNat major, KnownNat minor, KnownNat patch
|
||||
) => HasClient m (ApiVersion major minor patch :> NoContentVerb method) where
|
||||
type Client m (ApiVersion major minor patch :> NoContentVerb method) = Client m (NoContentVerb method)
|
||||
clientWithRoute pm _ = clientWithRoute pm (Proxy @(NoContentVerb method)) . Request.addHeader versionRequestHeaderName (semVerCompatibleTo version)
|
||||
where version = apiVersionToSemVer $ Proxy @(ApiVersion major minor patch)
|
||||
hoistClientMonad pm _ = hoistClientMonad pm $ Proxy @(NoContentVerb method)
|
||||
|
||||
|
||||
instance ( HasDocs (ApiVersionSub major minor patch sup sub)
|
||||
) => HasDocs (ApiVersion major minor patch :> ((sup :: Type) :> sub)) where
|
||||
docsFor _ = docsFor $ Proxy @(ApiVersionSub major minor patch sup sub)
|
||||
@ -263,3 +319,105 @@ type family IsLT x where
|
||||
type instance IsElem' sa (CaptureCryptoID' mods ciphertext sym plaintext :> sb) = IsElem sa (Capture' mods sym (CryptoID ciphertext plaintext) :> sb)
|
||||
|
||||
type instance IsElem' sa (ApiVersion major minor patch :> sb) = IsElem sa sb
|
||||
|
||||
|
||||
type family StripBearer api where
|
||||
StripBearer (CaptureBearerRestriction' mods restr :> sub) = sub
|
||||
StripBearer (CaptureBearerToken' mods :> sub) = sub
|
||||
StripBearer (BearerAuth :> sub) = sub
|
||||
StripBearer (sup :> sub) = sup :> StripBearer sub
|
||||
StripBearer (a :<|> b) = StripBearer a :<|> StripBearer b
|
||||
StripBearer (Verb method statusCode contentTypes a) = Verb method statusCode contentTypes a
|
||||
StripBearer (NoContentVerb method) = NoContentVerb method
|
||||
|
||||
type family BearerRequired api where
|
||||
BearerRequired (CaptureBearerRestriction' mods restr :> sub) = OrBool (FoldRequired mods) (BearerRequired sub)
|
||||
BearerRequired (CaptureBearerToken' mods :> sub) = OrBool (FoldRequired mods) (BearerRequired sub)
|
||||
BearerRequired (BearerAuth :> sub) = 'True
|
||||
BearerRequired (sup :> sub) = BearerRequired sub
|
||||
BearerRequired (a :<|> b) = OrBool (BearerRequired a) (BearerRequired b)
|
||||
BearerRequired (Verb method statusCode contentTypes a) = 'False
|
||||
BearerRequired (NoContentVerb method) = 'False
|
||||
|
||||
type family OrBool a b where
|
||||
OrBool 'False 'False = 'False
|
||||
OrBool a b = 'True
|
||||
|
||||
maybeWithJwt :: forall (a :: Bool). SBoolI a => Proxy a -> If a Jwt (Maybe Jwt) -> Servant.Request -> Servant.Request
|
||||
maybeWithJwt _ mparam = case (sbool :: SBool a, mparam) of
|
||||
(STrue, jwt) -> add jwt
|
||||
(SFalse, mJwt) -> maybe id add mJwt
|
||||
where add (Jwt jwt) = Request.addHeader "Authorization" . decodeUtf8 $ "Bearer " <> jwt
|
||||
|
||||
instance ( HasClient m (StripBearer sub)
|
||||
, RunClient m
|
||||
, SBoolI (BearerRequired (CaptureBearerRestriction' mods restr :> sub))
|
||||
) => HasClient m (CaptureBearerRestriction' mods restr :> sub) where
|
||||
type Client m (CaptureBearerRestriction' mods restr :> sub) = If (BearerRequired (CaptureBearerRestriction' mods restr :> sub)) Jwt (Maybe Jwt) -> Client m (StripBearer sub)
|
||||
clientWithRoute pm _ req mparam = clientWithRoute pm (Proxy @(StripBearer sub)) $ maybeWithJwt (Proxy @(BearerRequired (CaptureBearerRestriction' mods restr :> sub))) mparam req
|
||||
hoistClientMonad pm _ f cl = hoistClientMonad pm (Proxy @(StripBearer sub)) f . cl
|
||||
|
||||
instance ( HasClient m (StripBearer sub)
|
||||
, RunClient m
|
||||
, SBoolI (BearerRequired (CaptureBearerToken' mods :> sub))
|
||||
) => HasClient m (CaptureBearerToken' mods :> sub) where
|
||||
type Client m (CaptureBearerToken' mods :> sub) = If (BearerRequired (CaptureBearerToken' mods :> sub)) Jwt (Maybe Jwt) -> Client m (StripBearer sub)
|
||||
clientWithRoute pm _ req mparam = clientWithRoute pm (Proxy @(StripBearer sub)) $ maybeWithJwt (Proxy @(BearerRequired (CaptureBearerToken' mods :> sub))) mparam req
|
||||
hoistClientMonad pm _ f cl = hoistClientMonad pm (Proxy @(StripBearer sub)) f . cl
|
||||
|
||||
instance ( HasClient m (StripBearer sub)
|
||||
, RunClient m
|
||||
, SBoolI (BearerRequired (BearerAuth :> sub))
|
||||
) => HasClient m (BearerAuth :> sub) where
|
||||
type Client m (BearerAuth :> sub) = If (BearerRequired (BearerAuth :> sub)) Jwt (Maybe Jwt) -> Client m (StripBearer sub)
|
||||
clientWithRoute pm _ req mparam = clientWithRoute pm (Proxy @(StripBearer sub)) $ maybeWithJwt (Proxy @(BearerRequired (BearerAuth :> sub))) mparam req
|
||||
hoistClientMonad pm _ f cl = hoistClientMonad pm (Proxy @(StripBearer sub)) f . cl
|
||||
|
||||
|
||||
data BearerAuth
|
||||
data SessionAuth
|
||||
|
||||
instance HasSwagger sub => HasSwagger (BearerAuth :> sub) where
|
||||
toSwagger _ = toSwagger (Proxy @sub)
|
||||
& securityDefinitions <>~ SecurityDefinitions (fromList [(defnKey, defn)])
|
||||
& allOperations . security <>~ [SecurityRequirement $ fromList [(defnKey, [])]]
|
||||
where defnKey :: Text
|
||||
defnKey = "bearer"
|
||||
defn = SecurityScheme
|
||||
{ _securitySchemeType
|
||||
= SecuritySchemeApiKey ApiKeyParams
|
||||
{ _apiKeyName = "Authorization"
|
||||
, _apiKeyIn = ApiKeyHeader
|
||||
}
|
||||
, _securitySchemeDescription = Just
|
||||
"JSON Web Token-based API key"
|
||||
}
|
||||
|
||||
instance HasSwagger sub => HasSwagger (SessionAuth :> sub) where
|
||||
toSwagger _ = toSwagger (Proxy @sub)
|
||||
& allOperations . security <>~ [SecurityRequirement mempty]
|
||||
-- We do not expect API clients to be able/willing to conform with
|
||||
-- our CSRF mitigation, so we mark routes that require it as
|
||||
-- having unfullfillable security requirements
|
||||
|
||||
instance HasLink sub => HasLink (BearerAuth :> sub) where
|
||||
type MkLink (BearerAuth :> sub) a = MkLink sub a
|
||||
toLink toA _ = toLink toA (Proxy @sub)
|
||||
|
||||
instance HasLink sub => HasLink (SessionAuth :> sub) where
|
||||
type MkLink (SessionAuth :> sub) a = MkLink sub a
|
||||
toLink toA _ = toLink toA (Proxy @sub)
|
||||
|
||||
instance HasDocs sub => HasDocs (BearerAuth :> sub) where
|
||||
docsFor _ (endpoint, action) = docsFor (Proxy @sub) (endpoint, action')
|
||||
where action' = action & authInfo %~ (|> authInfo')
|
||||
authInfo' = DocAuthentication
|
||||
""
|
||||
"A JSON Web Token-based API key"
|
||||
|
||||
instance HasDocs sub => HasDocs (SessionAuth :> sub) where
|
||||
docsFor _ (endpoint, action) = docsFor (Proxy @sub) (endpoint, action')
|
||||
where action' = action & authInfo %~ (|> authInfo')
|
||||
authInfo' = DocAuthentication
|
||||
"When a web session is used for authorization, CSRF-mitigation measures must be observed."
|
||||
"An active web session identifying the user as one with sufficient authorization"
|
||||
|
||||
@ -40,7 +40,7 @@ import Yesod.Core.Types as Import (loggerSet)
|
||||
import Yesod.Default.Config2 as Import
|
||||
import Yesod.Core.Types.Instances as Import
|
||||
import Yesod.Servant as Import
|
||||
hiding ( MonadHandler(..), HasRoute(..)
|
||||
hiding ( MonadHandler(..), HasRoute(..), MonadRequest(..)
|
||||
, runDB, defaultRunDB
|
||||
)
|
||||
import Servant.Docs as Import
|
||||
@ -210,6 +210,7 @@ import Data.MonoTraversable.Instances as Import ()
|
||||
import Servant.Client.Core.BaseUrl.Instances as Import ()
|
||||
import Control.Monad.Trans.Except.Instances as Import ()
|
||||
import Servant.Server.Instances as Import ()
|
||||
import Servant.Docs.Internal.Pretty.Instances as Import ()
|
||||
import Network.URI.Instances as Import ()
|
||||
import Data.HashSet.Instances as Import ()
|
||||
import Web.Cookie.Instances as Import ()
|
||||
|
||||
@ -14,6 +14,7 @@ import Import.NoFoundation as Import hiding
|
||||
, MonadHandler(..), HasRoute(..), liftHandler
|
||||
, encrypt, decrypt
|
||||
, Unique, Fragment(..), respond
|
||||
, getRequest
|
||||
)
|
||||
|
||||
import Yesod.Servant as Import
|
||||
|
||||
14
src/Servant/Docs/Internal/Pretty/Instances.hs
Normal file
14
src/Servant/Docs/Internal/Pretty/Instances.hs
Normal file
@ -0,0 +1,14 @@
|
||||
{-# OPTIONS_GHC -fno-warn-orphans #-}
|
||||
|
||||
module Servant.Docs.Internal.Pretty.Instances () where
|
||||
|
||||
import ClassyPrelude
|
||||
|
||||
import Servant.Docs.Internal.Pretty
|
||||
import Servant.API.ContentTypes
|
||||
|
||||
import Data.Proxy
|
||||
|
||||
|
||||
instance MimeUnrender JSON a => MimeUnrender PrettyJSON a where
|
||||
mimeUnrender _ = mimeUnrender $ Proxy @JSON
|
||||
@ -15,17 +15,22 @@ import Jose.Jwk (JwkSet(..))
|
||||
{-# ANN module ("HLint: ignore Use newtype instead of data" :: String) #-}
|
||||
|
||||
|
||||
type ExternalApisListR = Get '[PrettyJSON] ExternalApisList
|
||||
type ExternalApisCreateR = CaptureBearerRestriction' '[Optional] ExternalApiCreationRestrictions
|
||||
type ExternalApisListR = ApiVersion 1 0 0
|
||||
:> Get '[PrettyJSON] ExternalApisList
|
||||
type ExternalApisCreateR = ApiVersion 1 0 0
|
||||
:> CaptureBearerRestriction' '[Optional] ExternalApiCreationRestrictions
|
||||
:> CaptureBearerToken
|
||||
:> ReqBody '[JSON] ExternalApiCreationRequest
|
||||
:> PostCreated '[PrettyJSON] (Headers '[Header "Location" URI] ExternalApiCreationResponse)
|
||||
type ExternalApisPongR = CaptureCryptoUUID "external-api" ExternalApiId
|
||||
type ExternalApisPongR = ApiVersion 1 0 0
|
||||
:> CaptureCryptoUUID "external-api" ExternalApiId
|
||||
:> "pong"
|
||||
:> Post '[PrettyJSON] ExternalApiPongResponse
|
||||
type ExternalApisInfoR = CaptureCryptoUUID "external-api" ExternalApiId
|
||||
type ExternalApisInfoR = ApiVersion 1 0 0
|
||||
:> CaptureCryptoUUID "external-api" ExternalApiId
|
||||
:> Get '[PrettyJSON] ExternalApiInfo
|
||||
type ExternalApisDeleteR = CaptureCryptoUUID "external-api" ExternalApiId
|
||||
type ExternalApisDeleteR = ApiVersion 1 0 0
|
||||
:> CaptureCryptoUUID "external-api" ExternalApiId
|
||||
:> DeleteNoContent
|
||||
|
||||
data ExternalApis mode = ExternalApis
|
||||
@ -37,7 +42,7 @@ data ExternalApis mode = ExternalApis
|
||||
} deriving (Generic)
|
||||
|
||||
type ServantApiExternalApis = ServantApi ExternalApis
|
||||
type instance ServantApiUnproxy ExternalApis = ApiVersion 1 0 0 :> ToServantApi ExternalApis
|
||||
type instance ServantApiUnproxy ExternalApis = ToServantApi ExternalApis
|
||||
|
||||
|
||||
instance ToCapture (Capture "external-api" UUID) where
|
||||
@ -122,7 +127,9 @@ data ExternalApiInfo = ExternalApiInfo
|
||||
|
||||
instance ToJSON ExternalApiInfo where
|
||||
toJSON ExternalApiInfo{..} = object $ maybe id ((:) . ("ident" .=)) eaiIdent
|
||||
[ "token-authority" .= foldMap (HashSet.singleton . either id toJSON) eaiTokenAuthority
|
||||
[ "token-authority" .= case HashSet.toList eaiTokenAuthority of
|
||||
[x] -> either id toJSON x
|
||||
_ -> toJSON $ foldMap (HashSet.singleton . either id toJSON) eaiTokenAuthority
|
||||
, "token-issued" .= eaiTokenIssued
|
||||
, "token-expires-at" .= eaiTokenExpiresAt
|
||||
, "token-starts-at" .= eaiTokenStartsAt
|
||||
@ -134,11 +141,11 @@ instance ToJSON ExternalApiInfo where
|
||||
|
||||
instance FromJSON ExternalApiInfo where
|
||||
parseJSON = withObject "ExternalApiInfo" $ \o -> do
|
||||
eaiIdent <- o .:? "token-authority"
|
||||
eaiIdent <- o .:? "ident"
|
||||
eaiTokenAuthority <- asum
|
||||
[ HashSet.singleton . Right <$> o .: "authority"
|
||||
, (o .: "authority" :: _ (HashSet Value)) >>= foldMapM (\v' -> fmap HashSet.singleton $ (Right <$> parseJSON v') <|> return (Left v'))
|
||||
, HashSet.singleton . Left <$> o .: "authority"
|
||||
[ HashSet.singleton . Right <$> o .: "token-authority"
|
||||
, (o .: "token-authority" :: _ (HashSet Value)) >>= foldMapM (\v' -> fmap HashSet.singleton $ (Right <$> parseJSON v') <|> return (Left v'))
|
||||
, HashSet.singleton . Left <$> o .: "token-authority"
|
||||
]
|
||||
eaiTokenIssued <- o .: "token-issued"
|
||||
eaiTokenExpiresAt <- o .: "token-expires-at"
|
||||
|
||||
@ -113,7 +113,9 @@ import Data.Binary (Binary)
|
||||
import qualified Data.Binary as Binary
|
||||
|
||||
import Network.Wai (requestMethod)
|
||||
import Network.HTTP.Types.Header
|
||||
import Network.HTTP.Types.Header as Wai
|
||||
|
||||
import Web.HttpApiData
|
||||
|
||||
import Data.Time.Clock
|
||||
|
||||
@ -1143,6 +1145,9 @@ addCustomHeader, replaceOrAddCustomHeader :: (MonadHandler m, PathPiece payload)
|
||||
addCustomHeader ident payload = addHeader (toPathPiece ident) (toPathPiece payload)
|
||||
replaceOrAddCustomHeader ident payload = replaceOrAddHeader (toPathPiece ident) (toPathPiece payload)
|
||||
|
||||
waiCustomHeader :: ToHttpApiData payload => CustomHeader -> payload -> Wai.Header
|
||||
waiCustomHeader ident payload = (CI.mk . encodeUtf8 $ toPathPiece ident, toHeader payload)
|
||||
|
||||
------------------
|
||||
-- HTTP Headers --
|
||||
------------------
|
||||
|
||||
@ -8,8 +8,8 @@ module Yesod.Servant
|
||||
, ServantApiDispatch(..)
|
||||
, servantApiLink
|
||||
, ServantHandlerFor(..)
|
||||
, ServantHandlerContextFor(..), getServantContext, getsServantContext, getYesodApproot, renderRouteAbsolute
|
||||
, MonadServantHandler(..), MonadHandler(..), MonadSite(..)
|
||||
, HasServantHandlerContext(..), getServantContext, getsServantContext, getYesodApproot, renderRouteAbsolute, servantApiBaseUrl
|
||||
, MonadServantHandler(..), MonadHandler(..), MonadSite(..), MonadRequest(..)
|
||||
, ServantDBFor, ServantPersist(..), defaultRunDB
|
||||
, ServantLog(..), ServantLogYesod(..)
|
||||
, mkYesodApi
|
||||
@ -45,6 +45,8 @@ import Servant.API
|
||||
import Servant.Server hiding (route)
|
||||
import Servant.Server.Instances ()
|
||||
|
||||
import Servant.Client.Core.BaseUrl
|
||||
|
||||
import Data.Proxy
|
||||
|
||||
import Network.Wai (Request, Middleware)
|
||||
@ -59,13 +61,10 @@ import Control.Monad.Fail (MonadFail(..))
|
||||
|
||||
import Data.Data (Data)
|
||||
import Data.Kind (Type)
|
||||
import GHC.Exts (IsList(..), Constraint)
|
||||
import GHC.Exts (Constraint)
|
||||
|
||||
import Servant.Swagger
|
||||
import Data.Swagger
|
||||
|
||||
import Servant.Docs
|
||||
|
||||
import qualified Data.Set as Set
|
||||
|
||||
import Network.HTTP.Types.Status
|
||||
@ -138,75 +137,92 @@ instance (Typeable m, Typeable k, Typeable status, Typeable fr, Typeable ct, Typ
|
||||
instance HasRoute sub => HasRoute (HttpVersion :> sub) where
|
||||
parseServantRoute args = parseServantRoute @sub @(ServantApiDirect sub) args <&> \case
|
||||
ServantApiRoute (_ :: Proxy endpoint) f ps qs -> ServantApiRoute (Proxy @(HttpVersion :> endpoint)) f ps qs
|
||||
ServantApiBaseRoute -> ServantApiBaseRoute
|
||||
|
||||
instance HasRoute sub => HasRoute (Vault :> sub) where
|
||||
parseServantRoute args = parseServantRoute @sub @(ServantApiDirect sub) args <&> \case
|
||||
ServantApiRoute (_ :: Proxy endpoint) f ps qs -> ServantApiRoute (Proxy @(Vault :> endpoint)) f ps qs
|
||||
ServantApiBaseRoute -> ServantApiBaseRoute
|
||||
|
||||
instance (HasRoute sub, KnownSymbol realm, Typeable a) => HasRoute (BasicAuth realm a :> sub) where
|
||||
parseServantRoute args = parseServantRoute @sub @(ServantApiDirect sub) args <&> \case
|
||||
ServantApiRoute (_ :: Proxy endpoint) f ps qs -> ServantApiRoute (Proxy @(BasicAuth realm a :> endpoint)) f ps qs
|
||||
ServantApiBaseRoute -> ServantApiBaseRoute
|
||||
|
||||
instance (HasRoute sub, KnownSymbol s) => HasRoute (Description s :> sub) where
|
||||
parseServantRoute args = parseServantRoute @sub @(ServantApiDirect sub) args <&> \case
|
||||
ServantApiRoute (_ :: Proxy endpoint) f ps qs -> ServantApiRoute (Proxy @(Description s :> endpoint)) f ps qs
|
||||
ServantApiBaseRoute -> ServantApiBaseRoute
|
||||
|
||||
instance (HasRoute sub, KnownSymbol s) => HasRoute (Summary s :> sub) where
|
||||
parseServantRoute args = parseServantRoute @sub @(ServantApiDirect sub) args <&> \case
|
||||
ServantApiRoute (_ :: Proxy endpoint) f ps qs -> ServantApiRoute (Proxy @(Summary s :> endpoint)) f ps qs
|
||||
ServantApiBaseRoute -> ServantApiBaseRoute
|
||||
|
||||
instance (HasRoute sub, Typeable tag, Typeable k) => HasRoute (AuthProtect (tag :: k) :> sub) where
|
||||
parseServantRoute args = parseServantRoute @sub @(ServantApiDirect sub) args <&> \case
|
||||
ServantApiRoute (_ :: Proxy endpoint) f ps qs -> ServantApiRoute (Proxy @(AuthProtect tag :> endpoint)) f ps qs
|
||||
ServantApiBaseRoute -> ServantApiBaseRoute
|
||||
|
||||
instance HasRoute sub => HasRoute (IsSecure :> sub) where
|
||||
parseServantRoute args = parseServantRoute @sub @(ServantApiDirect sub) args <&> \case
|
||||
ServantApiRoute (_ :: Proxy endpoint) f ps qs -> ServantApiRoute (Proxy @(IsSecure :> endpoint)) f ps qs
|
||||
ServantApiBaseRoute -> ServantApiBaseRoute
|
||||
|
||||
instance HasRoute sub => HasRoute (RemoteHost :> sub) where
|
||||
parseServantRoute args = parseServantRoute @sub @(ServantApiDirect sub) args <&> \case
|
||||
ServantApiRoute (_ :: Proxy endpoint) f ps qs -> ServantApiRoute (Proxy @(RemoteHost :> endpoint)) f ps qs
|
||||
ServantApiBaseRoute -> ServantApiBaseRoute
|
||||
|
||||
instance (HasRoute sub, Typeable mods, Typeable restr) => HasRoute (CaptureBearerRestriction' mods restr :> sub) where
|
||||
parseServantRoute args = parseServantRoute @sub @(ServantApiDirect sub) args <&> \case
|
||||
ServantApiRoute (_ :: Proxy endpoint) f ps qs -> ServantApiRoute (Proxy @(CaptureBearerRestriction' mods restr :> endpoint)) f ps qs
|
||||
ServantApiBaseRoute -> ServantApiBaseRoute
|
||||
|
||||
instance (HasRoute sub, Typeable mods) => HasRoute (CaptureBearerToken' mods :> sub) where
|
||||
parseServantRoute args = parseServantRoute @sub @(ServantApiDirect sub) args <&> \case
|
||||
ServantApiRoute (_ :: Proxy endpoint) f ps qs -> ServantApiRoute (Proxy @(CaptureBearerToken' mods :> endpoint)) f ps qs
|
||||
ServantApiBaseRoute -> ServantApiBaseRoute
|
||||
|
||||
instance (KnownSymbol sym, HasRoute sub, HasLink sub) => HasRoute (sym :> sub) where
|
||||
parseServantRoute (p : ps, qs)
|
||||
| p == escapedSymbol (Proxy @sym)
|
||||
= parseServantRoute @sub @(ServantApiDirect sub) (ps, qs) <&> \case
|
||||
ServantApiRoute (_ :: Proxy endpoint) f ps' qs' -> ServantApiRoute (Proxy @(sym :> endpoint)) f (escapedSymbol (Proxy @sym) : ps') qs'
|
||||
ServantApiBaseRoute -> ServantApiBaseRoute
|
||||
parseServantRoute _ = Nothing
|
||||
|
||||
instance (HasRoute a, HasRoute b) => HasRoute (a :<|> b) where
|
||||
parseServantRoute args = asum
|
||||
[ parseServantRoute @a @(ServantApiDirect a) args <&> \case
|
||||
ServantApiRoute (_ :: Proxy endpoint) f ps qs -> ServantApiRoute (Proxy @endpoint) f ps qs
|
||||
ServantApiBaseRoute -> ServantApiBaseRoute
|
||||
, parseServantRoute @b @(ServantApiDirect b) args <&> \case
|
||||
ServantApiRoute (_ :: Proxy endpoint) f ps qs -> ServantApiRoute (Proxy @endpoint) f ps qs
|
||||
ServantApiBaseRoute -> ServantApiBaseRoute
|
||||
]
|
||||
|
||||
instance (HasRoute sub, Typeable mods, Typeable ct, Typeable a) => HasRoute (ReqBody' mods ct a :> sub) where
|
||||
parseServantRoute args = parseServantRoute @sub @(ServantApiDirect sub) args <&> \case
|
||||
ServantApiRoute (_ :: Proxy endpoint) f ps qs -> ServantApiRoute (Proxy @(ReqBody' mods ct a :> endpoint)) f ps qs
|
||||
ServantApiBaseRoute -> ServantApiBaseRoute
|
||||
|
||||
instance (HasRoute sub, Typeable mods, Typeable framing, Typeable ct, Typeable a) => HasRoute (StreamBody' mods framing ct a :> sub) where
|
||||
parseServantRoute args = parseServantRoute @sub @(ServantApiDirect sub) args <&> \case
|
||||
ServantApiRoute (_ :: Proxy endpoint) f ps qs -> ServantApiRoute (Proxy @(StreamBody' mods framing ct a :> endpoint)) f ps qs
|
||||
ServantApiBaseRoute -> ServantApiBaseRoute
|
||||
|
||||
instance (HasRoute sub, KnownSymbol sym, Typeable mods, Typeable a) => HasRoute (Header' mods sym (a :: Type) :> sub) where
|
||||
parseServantRoute args = parseServantRoute @sub @(ServantApiDirect sub) args <&> \case
|
||||
ServantApiRoute (_ :: Proxy endpoint) f ps qs -> ServantApiRoute (Proxy @(Header' mods sym a :> endpoint)) f ps qs
|
||||
ServantApiBaseRoute -> ServantApiBaseRoute
|
||||
|
||||
instance (HasRoute sub, Typeable mods, KnownSymbol sym, Typeable v, ToHttpApiDataInjective v, FromHttpApiData v) => HasRoute (Capture' mods sym (v :: Type) :> sub) where
|
||||
parseServantRoute (p : ps, qs)
|
||||
| Right v <- parseUrlPiece @v p
|
||||
= parseServantRoute @sub @(ServantApiDirect sub) (ps, qs) <&> \case
|
||||
ServantApiRoute (_ :: Proxy endpoint) f ps' qs' -> ServantApiRoute (Proxy @(Capture' mods sym v :> endpoint)) (f . ($ v)) (toUrlPieceInjective v : ps') qs'
|
||||
ServantApiBaseRoute -> ServantApiBaseRoute
|
||||
parseServantRoute _ = Nothing
|
||||
|
||||
instance (HasRoute sub, Typeable mods, KnownSymbol sym, Typeable plaintext, ToHttpApiDataInjective ciphertext, FromHttpApiData ciphertext, Typeable ciphertext) => HasRoute (CaptureCryptoID' mods ciphertext sym plaintext :> sub) where
|
||||
@ -214,11 +230,13 @@ instance (HasRoute sub, Typeable mods, KnownSymbol sym, Typeable plaintext, ToHt
|
||||
| Right v <- parseUrlPiece @(CryptoID ciphertext plaintext) p
|
||||
= parseServantRoute @sub @(ServantApiDirect sub) (ps, qs) <&> \case
|
||||
ServantApiRoute (_ :: Proxy endpoint) f ps' qs' -> ServantApiRoute (Proxy @(CaptureCryptoID' mods ciphertext sym plaintext :> endpoint)) (f . ($ v)) (toUrlPieceInjective v : ps') qs'
|
||||
ServantApiBaseRoute -> ServantApiBaseRoute
|
||||
parseServantRoute _ = Nothing
|
||||
|
||||
instance (HasRoute sub, KnownNat major, KnownNat minor, KnownNat patch) => HasRoute (ApiVersion major minor patch :> sub) where
|
||||
parseServantRoute args = parseServantRoute @sub @(ServantApiDirect sub) args <&> \case
|
||||
ServantApiRoute (_ :: Proxy endpoint) f ps qs -> ServantApiRoute (Proxy @(ApiVersion major minor patch :> endpoint)) f ps qs
|
||||
ServantApiBaseRoute -> ServantApiBaseRoute
|
||||
|
||||
|
||||
data ServantApi (proxy :: k) = ServantApi
|
||||
@ -249,45 +267,56 @@ instance HasRoute (ServantApiUnproxy' proxy) => RenderRoute (ServantApi proxy) w
|
||||
(Proxy endpoint)
|
||||
(forall a. MkLink endpoint a -> a)
|
||||
[Text] (HashMap Text [Text])
|
||||
| ServantApiBaseRoute
|
||||
renderRoute (ServantApiRoute (_ :: Proxy endpoint) f _ _) = f $ safeLink' renderServantRoute (Proxy @(ServantApiUnproxy' proxy)) (Proxy @endpoint)
|
||||
renderRoute ServantApiBaseRoute = mempty
|
||||
|
||||
instance HasRoute (ServantApiUnproxy' proxy) => Eq (Route (ServantApi proxy)) where
|
||||
(ServantApiRoute (_ :: Proxy endpoint) _ ps qs) == (ServantApiRoute (_ :: Proxy endpoint') _ ps' qs')
|
||||
= case eqT @endpoint @endpoint' of
|
||||
Just Refl -> ps == ps' && qs == qs'
|
||||
Nothing -> False
|
||||
ServantApiBaseRoute == ServantApiBaseRoute = True
|
||||
_ == _ = False
|
||||
|
||||
instance HasRoute (ServantApiUnproxy' proxy) => Ord (Route (ServantApi proxy)) where
|
||||
compare (ServantApiRoute (_ :: Proxy endpoint) _ ps qs) (ServantApiRoute (_ :: Proxy endpoint') _ ps' qs')
|
||||
= case eqT @endpoint @endpoint' of
|
||||
Just Refl -> compare ps ps' <> compare qs qs'
|
||||
Nothing -> typeRep (Proxy @endpoint) `compare` typeRep (Proxy @endpoint')
|
||||
compare ServantApiBaseRoute ServantApiBaseRoute = EQ
|
||||
compare ServantApiBaseRoute _ = LT
|
||||
compare _ ServantApiBaseRoute = GT
|
||||
|
||||
instance HasRoute (ServantApiUnproxy' proxy) => Hashable (Route (ServantApi proxy)) where
|
||||
hashWithSalt salt (ServantApiRoute (_ :: Proxy endpoint) _ ps qs) = salt `hashWithSalt` typeRep (Proxy @endpoint) `hashWithSalt` ps `hashWithSalt` qs
|
||||
hashWithSalt salt (ServantApiRoute (_ :: Proxy endpoint) _ ps qs) = salt `hashWithSalt` (0 :: Int) `hashWithSalt` typeRep (Proxy @endpoint) `hashWithSalt` ps `hashWithSalt` qs
|
||||
hashWithSalt salt ServantApiBaseRoute = salt `hashWithSalt` (1 :: Int)
|
||||
|
||||
instance HasRoute (ServantApiUnproxy' proxy) => Read (Route (ServantApi proxy)) where
|
||||
readPrec = readP_to_Prec $ \d -> do
|
||||
when (d > 10) . void $ R.char '('
|
||||
R.skipSpaces
|
||||
void $ R.string "ServantApiRoute "
|
||||
R.skipSpaces
|
||||
void $ R.string "_ "
|
||||
R.skipSpaces
|
||||
asum [ do
|
||||
void $ R.char '('
|
||||
R.skipMany . R.manyTill (R.satisfy $ const True) $ R.char ')'
|
||||
void $ R.char ' '
|
||||
, R.skipMany . R.manyTill (R.satisfy $ not . Char.isSpace) $ R.satisfy Char.isSpace
|
||||
]
|
||||
R.skipSpaces
|
||||
ps <- readPrec_to_P readPrec 11
|
||||
void $ R.char ' '
|
||||
R.skipSpaces
|
||||
qs <- readPrec_to_P readPrec 11 :: R.ReadP (HashMap Text [Text])
|
||||
R.skipSpaces
|
||||
when (d > 10) . void $ R.char ')'
|
||||
maybe (fail "Could not parse servant route") return $ parseServantRoute (ps, ifoldMap (fmap . (,)) qs)
|
||||
readPrec = readP_to_Prec $ \d -> asum
|
||||
[ ServantApiBaseRoute <$ R.string "ServantApiBaseRoute"
|
||||
, do
|
||||
when (d > 10) . void $ R.char '('
|
||||
R.skipSpaces
|
||||
void $ R.string "ServantApiRoute "
|
||||
R.skipSpaces
|
||||
void $ R.string "_ "
|
||||
R.skipSpaces
|
||||
asum [ do
|
||||
void $ R.char '('
|
||||
R.skipMany . R.manyTill (R.satisfy $ const True) $ R.char ')'
|
||||
void $ R.char ' '
|
||||
, R.skipMany . R.manyTill (R.satisfy $ not . Char.isSpace) $ R.satisfy Char.isSpace
|
||||
]
|
||||
R.skipSpaces
|
||||
ps <- readPrec_to_P readPrec 11
|
||||
void $ R.char ' '
|
||||
R.skipSpaces
|
||||
qs <- readPrec_to_P readPrec 11 :: R.ReadP (HashMap Text [Text])
|
||||
R.skipSpaces
|
||||
when (d > 10) . void $ R.char ')'
|
||||
maybe (fail "Could not parse servant route") return $ parseServantRoute (ps, ifoldMap (fmap . (,)) qs)
|
||||
]
|
||||
instance HasRoute (ServantApiUnproxy' proxy) => Show (Route (ServantApi proxy)) where
|
||||
showsPrec d (ServantApiRoute (_ :: Proxy endpoint) _ ps qs) = showParen (d > 10)
|
||||
$ showString "ServantApiRoute "
|
||||
@ -296,6 +325,7 @@ instance HasRoute (ServantApiUnproxy' proxy) => Show (Route (ServantApi proxy))
|
||||
. showsPrec 11 ps
|
||||
. showString " "
|
||||
. showsPrec 11 qs
|
||||
showsPrec _ ServantApiBaseRoute = showString "ServantApiBaseRoute"
|
||||
|
||||
instance HasRoute (ServantApiUnproxy' proxy) => ParseRoute (ServantApi proxy) where
|
||||
parseRoute = parseServantRoute
|
||||
@ -385,10 +415,10 @@ servantApiLink _ _ = safeLink' (fromMaybe (error "Could not parse result of safe
|
||||
guardEndpoint _ = Nothing
|
||||
|
||||
|
||||
data ServantHandlerContextFor site = ServantHandlerContextFor
|
||||
{ sctxSite :: site
|
||||
, sctxRequest :: Request
|
||||
}
|
||||
class HasServantHandlerContext site where
|
||||
data ServantHandlerContextFor site :: Type
|
||||
getSCtxSite :: ServantHandlerContextFor site -> site
|
||||
getSCtxRequest :: ServantHandlerContextFor site -> Request
|
||||
|
||||
newtype ServantHandlerFor site a = ServantHandlerFor { unServantHandlerFor :: ServantHandlerContextFor site -> Handler a }
|
||||
deriving (Generic, Typeable)
|
||||
@ -404,10 +434,10 @@ getServantContext = liftServantHandler $ ServantHandlerFor return
|
||||
getsServantContext :: (site ~ site', MonadServantHandler site m) => (ServantHandlerContextFor site' -> a) -> m a
|
||||
getsServantContext = liftServantHandler . ServantHandlerFor . (return .)
|
||||
|
||||
getYesodApproot :: (Yesod site, MonadServantHandler site m) => m Text
|
||||
getYesodApproot = getsServantContext $ \ServantHandlerContextFor{..} -> Yesod.getApprootText Yesod.approot sctxSite sctxRequest
|
||||
getYesodApproot :: (Yesod site, MonadSite site m, MonadRequest m) => m Text
|
||||
getYesodApproot = Yesod.getApprootText Yesod.approot <$> getSite <*> getRequest
|
||||
|
||||
renderRouteAbsolute :: (Yesod site, MonadServantHandler site m) => Route site -> m URI
|
||||
renderRouteAbsolute :: (Yesod site, MonadSite site m, MonadRequest m) => Route site -> m URI
|
||||
renderRouteAbsolute (renderRoute -> (ps, qs)) = addRoute . unpack <$> getYesodApproot
|
||||
where addRoute root = case parseURI root of
|
||||
Just root' -> root' & uriPathLens . packed %~ addPath
|
||||
@ -419,13 +449,16 @@ renderRouteAbsolute (renderRoute -> (ps, qs)) = addRoute . unpack <$> getYesodAp
|
||||
addQuery "?" = addQuery ""
|
||||
addQuery q = q <> "&" <> tailEx (addQuery "")
|
||||
|
||||
class MonadIO m => MonadServantHandler site m | m -> site where
|
||||
servantApiBaseUrl :: (Yesod site, MonadSite site m, MonadRequest m, MonadThrow m) => (Route (ServantApi proxy) -> Route site) -> m BaseUrl
|
||||
servantApiBaseUrl = parseBaseUrl . ($ mempty). uriToString (const "") <=< renderRouteAbsolute . ($ ServantApiBaseRoute)
|
||||
|
||||
class (MonadIO m, HasServantHandlerContext site) => MonadServantHandler site m | m -> site where
|
||||
liftServantHandler :: forall a. ServantHandlerFor site a -> m a
|
||||
|
||||
instance MonadServantHandler site (ServantHandlerFor site) where
|
||||
instance HasServantHandlerContext site => MonadServantHandler site (ServantHandlerFor site) where
|
||||
liftServantHandler = id
|
||||
|
||||
instance (MonadTrans t, MonadIO (t (ServantHandlerFor site))) => MonadServantHandler site (t (ServantHandlerFor site)) where
|
||||
instance (MonadTrans t, MonadIO (t (ServantHandlerFor site)), HasServantHandlerContext site) => MonadServantHandler site (t (ServantHandlerFor site)) where
|
||||
liftServantHandler = lift
|
||||
|
||||
class MonadIO m => MonadHandler m where
|
||||
@ -443,8 +476,8 @@ class Monad m => MonadSite site m | m -> site where
|
||||
getsSite :: (site -> a) -> m a
|
||||
getsSite f = f <$> getSite
|
||||
|
||||
instance MonadSite site (ServantHandlerFor site) where
|
||||
getSite = liftServantHandler . ServantHandlerFor $ return . sctxSite
|
||||
instance HasServantHandlerContext site => MonadSite site (ServantHandlerFor site) where
|
||||
getSite = liftServantHandler . ServantHandlerFor $ return . getSCtxSite
|
||||
|
||||
instance MonadSite site (Reader site) where
|
||||
getSite = ask
|
||||
@ -454,10 +487,22 @@ instance {-# OVERLAPPABLE #-} (Yesod.MonadHandler m, site ~ Yesod.HandlerSite m)
|
||||
getSite = Yesod.getYesod
|
||||
getsSite = Yesod.getsYesod
|
||||
|
||||
instance {-# OVERLAPPING #-} (MonadTrans t, Monad (t (ServantHandlerFor site))) => MonadSite site (t (ServantHandlerFor site)) where
|
||||
instance {-# OVERLAPPING #-} (MonadTrans t, Monad (t (ServantHandlerFor site)), HasServantHandlerContext site) => MonadSite site (t (ServantHandlerFor site)) where
|
||||
getSite = lift getSite
|
||||
getsSite = lift . getsSite
|
||||
|
||||
class Monad m => MonadRequest m where
|
||||
getRequest :: m Request
|
||||
|
||||
instance HasServantHandlerContext site => MonadRequest (ServantHandlerFor site) where
|
||||
getRequest = liftServantHandler . ServantHandlerFor $ return . getSCtxRequest
|
||||
|
||||
instance {-# OVERLAPPABLE #-} (Yesod.MonadHandler m, Monad m) => MonadRequest m where
|
||||
getRequest = Yesod.waiRequest
|
||||
|
||||
instance {-# OVERLAPPING #-} (MonadTrans t, Monad (t (ServantHandlerFor site)), HasServantHandlerContext site) => MonadRequest (t (ServantHandlerFor site)) where
|
||||
getRequest = lift getRequest
|
||||
|
||||
|
||||
type ServantDBFor site = ReaderT (Yesod.YesodPersistBackend site) (ServantHandlerFor site)
|
||||
|
||||
@ -466,6 +511,7 @@ class Yesod.YesodPersist site => ServantPersist site where
|
||||
|
||||
defaultRunDB :: ( PersistConfig c
|
||||
, ServantDBFor site a ~ PersistConfigBackend c (ServantHandlerFor site) a
|
||||
, HasServantHandlerContext site
|
||||
)
|
||||
=> Getting c site c
|
||||
-> Getting (PersistConfigPool c) site (PersistConfigPool c)
|
||||
@ -485,12 +531,12 @@ instance Yesod site => ServantLog (ServantLogYesod site) where
|
||||
logger <- Yesod.makeLogger app
|
||||
Yesod.messageLoggerSource app logger a b c d
|
||||
|
||||
instance ServantLog site => MonadLogger (ServantHandlerFor site) where
|
||||
instance (ServantLog site, HasServantHandlerContext site) => MonadLogger (ServantHandlerFor site) where
|
||||
monadLoggerLog a b c d = do
|
||||
app <- getSite
|
||||
servantLogLog app a b c d
|
||||
|
||||
instance ServantLog site => MonadLoggerIO (ServantHandlerFor site) where
|
||||
instance (ServantLog site, HasServantHandlerContext site) => MonadLoggerIO (ServantHandlerFor site) where
|
||||
askLoggerIO = servantLogLog <$> getSite
|
||||
|
||||
|
||||
@ -501,56 +547,6 @@ instance PathPiece a => FromHttpApiData (PathPieceHttpApiData a) where
|
||||
parseUrlPiece = maybe (Left "Could not convert from HttpApiData via PathPiece") Right . fromPathPiece
|
||||
instance PathPiece a => ToHttpApiData (PathPieceHttpApiData a) where
|
||||
toUrlPiece = toPathPiece
|
||||
|
||||
|
||||
data BearerAuth
|
||||
data SessionAuth
|
||||
|
||||
instance HasSwagger sub => HasSwagger (BearerAuth :> sub) where
|
||||
toSwagger _ = toSwagger (Proxy @sub)
|
||||
& securityDefinitions <>~ SecurityDefinitions (fromList [(defnKey, defn)])
|
||||
& allOperations . security <>~ [SecurityRequirement $ fromList [(defnKey, [])]]
|
||||
where defnKey :: Text
|
||||
defnKey = "bearer"
|
||||
defn = SecurityScheme
|
||||
{ _securitySchemeType
|
||||
= SecuritySchemeApiKey ApiKeyParams
|
||||
{ _apiKeyName = "Authorization"
|
||||
, _apiKeyIn = ApiKeyHeader
|
||||
}
|
||||
, _securitySchemeDescription = Just
|
||||
"JSON Web Token-based API key"
|
||||
}
|
||||
|
||||
instance HasSwagger sub => HasSwagger (SessionAuth :> sub) where
|
||||
toSwagger _ = toSwagger (Proxy @sub)
|
||||
& allOperations . security <>~ [SecurityRequirement mempty]
|
||||
-- We do not expect API clients to be able/willing to conform with
|
||||
-- our CSRF mitigation, so we mark routes that require it as
|
||||
-- having unfullfillable security requirements
|
||||
|
||||
instance HasLink sub => HasLink (BearerAuth :> sub) where
|
||||
type MkLink (BearerAuth :> sub) a = MkLink sub a
|
||||
toLink toA _ = toLink toA (Proxy @sub)
|
||||
|
||||
instance HasLink sub => HasLink (SessionAuth :> sub) where
|
||||
type MkLink (SessionAuth :> sub) a = MkLink sub a
|
||||
toLink toA _ = toLink toA (Proxy @sub)
|
||||
|
||||
instance HasDocs sub => HasDocs (BearerAuth :> sub) where
|
||||
docsFor _ (endpoint, action) = docsFor (Proxy @sub) (endpoint, action')
|
||||
where action' = action & authInfo %~ (|> authInfo')
|
||||
authInfo' = DocAuthentication
|
||||
""
|
||||
"A JSON Web Token-based API key"
|
||||
|
||||
instance HasDocs sub => HasDocs (SessionAuth :> sub) where
|
||||
docsFor _ (endpoint, action) = docsFor (Proxy @sub) (endpoint, action')
|
||||
where action' = action & authInfo %~ (|> authInfo')
|
||||
authInfo' = DocAuthentication
|
||||
"When a web session is used for authorization, CSRF-mitigation measures must be observed."
|
||||
"An active web session identifying the user as one with sufficient authorization"
|
||||
|
||||
|
||||
|
||||
mkYesodApi :: Name -> [ResourceTree String] -> DecsQ
|
||||
|
||||
@ -51,6 +51,8 @@ import qualified Data.SemVer as SemVer
|
||||
import qualified Data.SemVer.Constraint as SemVer (Constraint)
|
||||
import qualified Data.SemVer.Constraint as SemVer.Constraint
|
||||
|
||||
import qualified Data.HashSet as HashSet
|
||||
|
||||
|
||||
|
||||
instance Arbitrary Season where
|
||||
@ -343,7 +345,9 @@ instance Arbitrary RoomReference' where
|
||||
arbitrary = genericArbitrary
|
||||
|
||||
instance Arbitrary ExternalApiConfig where
|
||||
arbitrary = genericArbitrary
|
||||
arbitrary = oneof
|
||||
[ EApiGradelistFormat <$> ((fmap HashSet.fromList . scale (`div` 10) $ listOf1 (resize 3 arbitrary)) `suchThatMap` fromNullable)
|
||||
]
|
||||
shrink = genericShrink
|
||||
|
||||
instance Arbitrary SemVer.Version where
|
||||
|
||||
@ -7,7 +7,14 @@ import Network.URI
|
||||
import Network.URI.Arbitrary ()
|
||||
import Servant.Client.Core.BaseUrl
|
||||
|
||||
import Control.Lens.Extras
|
||||
|
||||
|
||||
instance Arbitrary BaseUrl where
|
||||
arbitrary = toBaseUrl <$> arbitrary
|
||||
where toBaseUrl = either (error . displayException) id . parseBaseUrl . ($ mempty) . uriToString id
|
||||
arbitrary = (`suchThatMap` toBaseUrl) $ do
|
||||
uri <- scale (min 10) arbitrary `suchThat` (is _Just . uriAuthority)
|
||||
uriScheme <- oneof $ map (return . (<> ":")) [ "http", "https" ]
|
||||
let uriAuthority'' = uriAuthority uri <&> \uriAuthority' -> uriAuthority'{ uriUserInfo = "" }
|
||||
return (uri, uriScheme, uriAuthority'')
|
||||
where
|
||||
toBaseUrl (uri, uriScheme, uriAuthority'') = either (const Nothing) Just . parseBaseUrl . ($ mempty) $ uriToString (const mempty) uri{ uriScheme, uriAuthority = uriAuthority'', uriQuery = "", uriFragment = "" }
|
||||
|
||||
@ -8,7 +8,10 @@ import ServantApi.ExternalApis.Type
|
||||
|
||||
|
||||
instance Arbitrary ExternalApiCreationRequest where
|
||||
arbitrary = genericArbitrary
|
||||
arbitrary = ExternalApiCreationRequest
|
||||
<$> scale (`div` 2) arbitrary
|
||||
<*> scale (`div` 2) arbitrary
|
||||
<*> scale (`div` 2) arbitrary
|
||||
shrink = genericShrink
|
||||
|
||||
|
||||
|
||||
48
test/ServantApi/ExternalApisSpec.hs
Normal file
48
test/ServantApi/ExternalApisSpec.hs
Normal file
@ -0,0 +1,48 @@
|
||||
{-# OPTIONS_GHC -Wno-error=unused-local-binds #-}
|
||||
|
||||
module ServantApi.ExternalApisSpec where
|
||||
|
||||
import TestImport
|
||||
import ServantApi.ExternalApis.Type
|
||||
import ServantApi.ExternalApis.TypeSpec ()
|
||||
|
||||
import Servant.Client.Core (RequestF(..))
|
||||
import Servant.Client.Generic
|
||||
|
||||
import Utils.Tokens
|
||||
import Data.Time.Clock (nominalDay)
|
||||
|
||||
import qualified Data.HashSet as HashSet
|
||||
import qualified Data.HashMap.Strict as HashMap
|
||||
|
||||
import qualified Data.Sequence as Seq
|
||||
|
||||
import Control.Monad.Reader.Class (MonadReader(local))
|
||||
import Utils (CustomHeader(..), waiCustomHeader)
|
||||
|
||||
|
||||
spec :: Spec
|
||||
spec = withApp . describe "ExternalApis" $ do
|
||||
it "Supports dryRun" $ do
|
||||
adminId <- runDB $ do
|
||||
Entity adminId _ <- insertEntity $ fakeUser id
|
||||
ifi <- insert $ School "Institut für Informatik" "IfI" (Just $ 14 * nominalDay) (Just $ 10 * nominalDay) True (ExamModeDNF predDNFFalse) (ExamCloseOnFinished True) SchoolAuthorshipStatementModeOptional Nothing True SchoolAuthorshipStatementModeRequired Nothing False
|
||||
insert_ $ UserFunction adminId ifi SchoolAdmin
|
||||
return adminId
|
||||
|
||||
accessToken <- runHandler $ encodeBearer =<< bearerToken (HashSet.singleton $ Right adminId) Nothing HashMap.empty Nothing Nothing Nothing
|
||||
|
||||
let
|
||||
insertExternalApi = void $ externalApisCreateR accessToken =<< liftIO (generate $ resize 10 arbitrary)
|
||||
where ExternalApis{..} = genericClient
|
||||
withDryRun :: ServantExampleEnv -> ServantExampleEnv
|
||||
withDryRun seEnv = seEnv
|
||||
{ yseMakeClientRequest = \burl req -> yseMakeClientRequest seEnv burl req{ requestHeaders = requestHeaders req Seq.:|> waiCustomHeader HeaderDryRun True }
|
||||
}
|
||||
externalApiCount = runDB $ count @_ @_ @ExternalApi []
|
||||
|
||||
runServantExample ExternalApisR insertExternalApi
|
||||
liftIO . (`shouldBe` 1) =<< externalApiCount
|
||||
|
||||
runServantExample ExternalApisR $ local withDryRun insertExternalApi
|
||||
liftIO . (`shouldBe` 1) =<< externalApiCount
|
||||
36
test/ServantApiSpec.hs
Normal file
36
test/ServantApiSpec.hs
Normal file
@ -0,0 +1,36 @@
|
||||
{-# LANGUAGE UndecidableInstances #-}
|
||||
{-# OPTIONS_GHC -fno-warn-redundant-constraints #-}
|
||||
|
||||
module ServantApiSpec where
|
||||
|
||||
import TestImport
|
||||
import ServantApi
|
||||
|
||||
import Servant.API
|
||||
import Servant.API.TypeLevel (MapSub, AppendList)
|
||||
import Foundation.Servant.Types (ApiVersion)
|
||||
|
||||
import GHC.TypeLits
|
||||
import Data.Kind (Constraint)
|
||||
|
||||
|
||||
type family Unversioned api where
|
||||
Unversioned (ApiVersion _ _ _ :> _) = '[]
|
||||
Unversioned (sup :> sub) = MapSub sup (Unversioned sub)
|
||||
Unversioned (a :<|> b) = AppendList (Unversioned a) (Unversioned b)
|
||||
Unversioned (Verb method statusCode contentTypes a) = '[Verb method statusCode contentTypes a]
|
||||
Unversioned (NoContentVerb method) = '[NoContentVerb method]
|
||||
|
||||
type family UnversionedError xs :: ErrorMessage where
|
||||
UnversionedError (x ': '[]) = 'Text "Unversioned API endpoint: " ':$$: ('Text " " ':<>: 'ShowType x)
|
||||
UnversionedError (x ': xs) = UnversionedError (x ': '[]) ':$$: UnversionedError xs
|
||||
|
||||
type family IsEmpty xs :: Constraint where
|
||||
IsEmpty '[] = ()
|
||||
IsEmpty xs = TypeError ('Text "All API endpoints must be versioned." ':$$: UnversionedError xs)
|
||||
|
||||
spec :: Spec
|
||||
spec = describe "Servant endpoints" $ it "are all versioned" versioned
|
||||
where
|
||||
versioned :: IsEmpty (Unversioned UniWorXApi) => Bool
|
||||
versioned = True
|
||||
@ -1,3 +1,5 @@
|
||||
{-# OPTIONS_GHC -fno-warn-deprecations #-}
|
||||
|
||||
module TestImport
|
||||
( module TestImport
|
||||
, module X
|
||||
@ -44,6 +46,34 @@ import Jobs (handleJobs)
|
||||
import Numeric.Natural as X
|
||||
import Network.URI.Arbitrary as X ()
|
||||
|
||||
import qualified Network.Wai as Wai
|
||||
import qualified Network.Wai.Test as Wai
|
||||
import qualified Network.Wai.Test.Internal as Wai (ClientState)
|
||||
import Network.HTTP.Types (Status(..), hContentType, hAccept)
|
||||
import Network.HTTP.Types.Header (hHost)
|
||||
import qualified Network.HTTP.Types as Wai
|
||||
|
||||
import Control.Monad.Trans.Except (ExceptT)
|
||||
import qualified Servant.Client.Core as Servant
|
||||
import Servant.Client.Core.ClientError
|
||||
import Servant.Client.Core.RunClient
|
||||
import Control.Monad.Except (MonadError(..))
|
||||
import Control.Monad.State.Class (MonadState(..))
|
||||
import qualified Control.Monad.State.Class as State
|
||||
import qualified Servant.Types.SourceT as S
|
||||
import Servant.API (SourceIO)
|
||||
|
||||
import Utils (throwExceptT)
|
||||
|
||||
import Yesod.Servant (ServantApi, servantApiBaseUrl)
|
||||
|
||||
import qualified Data.ByteString as BS
|
||||
import qualified Data.ByteString.Lazy as Lazy (ByteString)
|
||||
import qualified Data.ByteString.Lazy as LBS hiding (ByteString)
|
||||
import qualified Data.Binary.Builder as B
|
||||
import Network.HTTP.Media (renderHeader)
|
||||
import Control.Monad.Fail
|
||||
|
||||
import Control.Lens as X hiding ((<.), elements)
|
||||
|
||||
import Network.IP.Addr as X (IP)
|
||||
@ -133,3 +163,105 @@ lawsCheckHspec p = parallel . describe (show $ typeRep p) . mapM_ (checkHspec .
|
||||
where
|
||||
checkHspec (Laws className properties) = describe className $
|
||||
forM_ properties $ \(name, prop) -> it name $ property prop
|
||||
|
||||
|
||||
newtype ServantExample a = ServantExample
|
||||
{ unServantExample :: ReaderT ServantExampleEnv (ExceptT ClientError Wai.Session) a
|
||||
} deriving stock (Generic, Typeable)
|
||||
deriving newtype (Functor, Applicative, Monad, MonadIO, MonadReader ServantExampleEnv, MonadError ClientError, MonadThrow, MonadCatch, MonadState Wai.ClientState)
|
||||
|
||||
data ServantExampleEnv = ServantExampleEnv
|
||||
{ yseBaseUrl :: BaseUrl
|
||||
, yseMakeClientRequest :: BaseUrl -> Servant.Request -> IO Wai.Request
|
||||
} deriving (Generic, Typeable)
|
||||
|
||||
runServantExample :: (Route (ServantApi proxy) -> Route UniWorX) -> ServantExample a -> YesodExample UniWorX a
|
||||
runServantExample apiR (ServantExample act) = do
|
||||
yseBaseUrl <- runHandler $ servantApiBaseUrl apiR
|
||||
let yseMakeClientRequest burl Servant.Request{..} = do
|
||||
((body, bodyLength), contentTypeHdr) <- case requestBody of
|
||||
Nothing -> return ((return BS.empty, Wai.KnownLength 0), Nothing)
|
||||
Just (body', typ) -> let (mkBody, bLength) = convertBody body'
|
||||
in (, Just (hContentType, renderHeader typ)) . (, bLength) <$> mkBody
|
||||
|
||||
return $ Wai.defaultRequest
|
||||
{ Wai.requestMethod = requestMethod
|
||||
, Wai.requestHeaders = maybeToList acceptHdr ++ maybeToList contentTypeHdr ++ headers
|
||||
, Wai.requestHeaderHost =
|
||||
let BaseUrl{..} = yseBaseUrl
|
||||
in Just . encodeUtf8 . pack $ baseUrlHost <> bool (":" <> show baseUrlPort) mempty (baseUrlPort == 80)
|
||||
, Wai.requestBody = body, Wai.requestBodyLength = bodyLength
|
||||
, Wai.isSecure = isSecure
|
||||
}
|
||||
& flip Wai.setPath (encodeUtf8 (pack $ baseUrlPath burl) <> toStrict (B.toLazyByteString requestPath) <> Wai.renderQuery True (toList requestQueryString))
|
||||
where
|
||||
headers = filter (\(h, _) -> h `notElem` [hAccept, hContentType, hHost]) $ toList requestHeaders
|
||||
|
||||
acceptHdr
|
||||
| null hs = Nothing
|
||||
| otherwise = Just (hAccept, renderHeader hs)
|
||||
where
|
||||
hs = toList requestAccept
|
||||
|
||||
convertBody :: Servant.RequestBody -> (IO (IO ByteString), Wai.RequestBodyLength)
|
||||
convertBody bd = case bd of
|
||||
Servant.RequestBodyLBS body' -> ( givesPopper . S.source . map fromStrict $ LBS.toChunks body'
|
||||
, Wai.KnownLength . fromIntegral $ LBS.length body'
|
||||
)
|
||||
Servant.RequestBodyBS body' -> ( return $ return body'
|
||||
, Wai.KnownLength . fromIntegral $ BS.length body'
|
||||
)
|
||||
Servant.RequestBodySource sourceIO -> ( givesPopper sourceIO
|
||||
, Wai.ChunkedBody
|
||||
)
|
||||
where
|
||||
givesPopper :: SourceIO Lazy.ByteString -> IO (IO ByteString)
|
||||
givesPopper sourceIO = S.unSourceT sourceIO $ \step0 -> do
|
||||
ref <- newMVar step0
|
||||
return $ modifyMVar ref nextBs
|
||||
|
||||
nextBs S.Stop = return (S.Stop, BS.empty)
|
||||
nextBs (S.Error err) = fail err
|
||||
nextBs (S.Skip s) = nextBs s
|
||||
nextBs (S.Effect ms) = ms >>= nextBs
|
||||
nextBs (S.Yield lbs s) = case LBS.toChunks lbs of
|
||||
[] -> nextBs s
|
||||
(x:xs) | BS.null x -> nextBs step'
|
||||
| otherwise -> return (step', x)
|
||||
where
|
||||
step' = S.Yield (LBS.fromChunks xs) s
|
||||
|
||||
isSecure = case baseUrlScheme burl of
|
||||
Servant.Http -> False
|
||||
Servant.Https -> True
|
||||
YesodExampleData waiApp _ _ _ <- State.get
|
||||
liftIO . flip Wai.runSession waiApp . throwExceptT $ runReaderT act ServantExampleEnv{..}
|
||||
|
||||
instance RunClient ServantExample where
|
||||
runRequestAcceptStatus acceptStatus req = do
|
||||
ServantExampleEnv{..} <- ask
|
||||
waiRequest <- liftIO $ yseMakeClientRequest yseBaseUrl req
|
||||
waiResponse@Wai.SResponse{..} <- ServantExample . lift . lift $ Wai.request waiRequest
|
||||
let Status{..} = simpleStatus
|
||||
statusOk = case acceptStatus of
|
||||
Nothing -> 200 <= statusCode && statusCode < 300
|
||||
Just good -> simpleStatus `elem` good
|
||||
response = (waiResponseToResponse waiResponse) { Servant.responseHttpVersion = Wai.httpVersion waiRequest }
|
||||
unless statusOk $
|
||||
throwError $ mkFailureResponse yseBaseUrl req response
|
||||
return response
|
||||
where
|
||||
mkFailureResponse :: BaseUrl -> Servant.Request -> Servant.ResponseF Lazy.ByteString -> ClientError
|
||||
mkFailureResponse burl request' =
|
||||
FailureResponse (bimap (const ()) f request')
|
||||
where
|
||||
f b = (burl, LBS.toStrict $ B.toLazyByteString b)
|
||||
|
||||
waiResponseToResponse :: Wai.SResponse -> Servant.Response
|
||||
waiResponseToResponse Wai.SResponse{..} = Servant.Response
|
||||
{ responseStatusCode = simpleStatus
|
||||
, responseBody = simpleBody
|
||||
, responseHeaders = fromList simpleHeaders
|
||||
, responseHttpVersion = error "WAI Response does not carry http version information"
|
||||
}
|
||||
throwClientError = throwError
|
||||
|
||||
Reference in New Issue
Block a user