diff --git a/package.yaml b/package.yaml index d1f284c52..033a740eb 100644 --- a/package.yaml +++ b/package.yaml @@ -346,6 +346,7 @@ tests: - quickcheck-io - network-arbitrary - lens-properties + - http-media ghc-options: - -fno-warn-orphans - -threaded -rtsopts "-with-rtsopts=-N -T" diff --git a/src/Foundation/Authorization.hs b/src/Foundation/Authorization.hs index 4626a4c53..bbed1a1d9 100644 --- a/src/Foundation/Authorization.hs +++ b/src/Foundation/Authorization.hs @@ -9,6 +9,7 @@ module Foundation.Authorization , wouldHaveReadAccessToIff, wouldHaveWriteAccessToIff , AuthContext(..), getAuthContext , isDryRun, isDryRunDB + , IsDryRun(..) , maybeBearerToken, requireBearerToken , requireCurrentBearerRestrictions, maybeCurrentBearerRestrictions , BearerAuthSite, MonadAP @@ -276,7 +277,9 @@ getAuthContext = liftHandler $ do return authCtx newtype IsDryRun = MkIsDryRun { unIsDryRun :: Bool } - deriving (Eq, Ord, Read, Show, Generic, Typeable) + deriving stock (Read, Show, Generic, Typeable) + deriving newtype (Eq, Ord) + deriving (Semigroup, Monoid) via Any isDryRun :: ( HasCallStack , BearerAuthSite UniWorX @@ -296,6 +299,7 @@ isDryRunDB :: forall m backend m'. isDryRunDB = fmap unIsDryRun . cached . fmap MkIsDryRun $ orM [ hasGlobalPostParam PostDryRun , hasGlobalGetParam GetDryRun + , hasCustomHeader HeaderDryRun , and2M bearerDryRun bearerRequired ] where diff --git a/src/Foundation/Servant.hs b/src/Foundation/Servant.hs index e522d9094..380703d79 100644 --- a/src/Foundation/Servant.hs +++ b/src/Foundation/Servant.hs @@ -9,19 +9,15 @@ module Foundation.Servant import Import.Servant.NoFoundation import Foundation.DB (runSqlPoolRetry') -import Foundation.Authorization (maybeBearerToken) +import Foundation.Authorization (maybeBearerToken, IsDryRun(..), isDryRun) import Foundation.Instances () import qualified Data.HashMap.Strict.InsOrd as HashMap -import Network.Wai (Middleware, modifyResponse, mapResponseHeaders, vault) +import Network.Wai (Middleware, modifyResponse, mapResponseHeaders) import qualified Network.Wai as W -import qualified Data.Vault.Lazy as Vault - -import Servant.Server.Internal.DelayedIO (DelayedIO, delayedFail, delayedFailFatal, withRequest) - -import System.IO.Unsafe (unsafePerformIO) +import Servant.Server.Internal.DelayedIO (DelayedIO, delayedFail, delayedFailFatal) import qualified Yesod.Servant as Servant @@ -32,21 +28,16 @@ import Control.Monad.Catch.Pure import Servant.Server.Internal.Delayed import Servant.Server.Internal.Router --- import Database.Persist.Sql (transactionUndo) +import Database.Persist.Sql (transactionUndo) - -waiBearerKey :: Vault.Key (Maybe (BearerToken UniWorX)) -waiBearerKey = unsafePerformIO Vault.newKey -{-# NOINLINE waiBearerKey #-} - -waiRouteKey :: Vault.Key (Route UniWorX) -waiRouteKey = unsafePerformIO Vault.newKey -{-# NOINLINE waiRouteKey #-} +import qualified Data.CaseInsensitive as CI instance ( HasServer sub context , ToJSON restr, FromJSON restr , SBoolI (FoldRequired mods) + , HasContextEntry context (Maybe (BearerToken UniWorX)) + , HasContextEntry context (Maybe (Route UniWorX)) ) => HasServer (CaptureBearerRestriction' mods restr :> sub) context where @@ -56,25 +47,25 @@ instance ( HasServer sub context hoistServerWithContext _ pc nt s = hoistServerWithContext (Proxy @sub) pc nt . s route _ context subserver - = route (Proxy @sub) context (subserver `addAuthCheck` withRequest bearerCheck) + = route (Proxy @sub) context (subserver `addAuthCheck` bearerCheck) where - bearerCheck :: W.Request -> DelayedIO (RequiredArgument mods restr) - bearerCheck req = do - let bearer = Vault.lookup waiBearerKey $ vault req - cRoute = Vault.lookup waiRouteKey $ vault req + bearerCheck :: DelayedIO (RequiredArgument mods restr) + bearerCheck = do + let bearer :: Maybe (BearerToken UniWorX) + bearer = getContextEntry context + cRoute :: Maybe (Route UniWorX) + cRoute = getContextEntry context - noRouteStored, noTokenStored, noTokenProvided, noRestrictionProvided :: ServerError - noTokenStored = err500 { errBody = "servantYesodMiddleware did not store bearer token in WAI vault." } + noRouteStored, noTokenProvided, noRestrictionProvided :: ServerError noTokenProvided = err400 { errBody = "The behaviour of this route depends on restrictions stored in the bearer token used for authorization. Therefor providing a bearer token is required." } noRestrictionProvided = err400 { errBody = "The behaviour of this route depends on restrictions stored in the bearer token used for authorization. Therefor the provided bearer token must contain a restriction entry for this route." } noRouteStored = err500 { errBody = "servantYesodMiddleware did not store current route in WAI vault." } exceptT delayedFailFatal return $ do - bearer' <- maybeExceptT' noTokenStored bearer cRoute' <- maybeExceptT' noRouteStored cRoute let mbRet :: Maybe (Maybe restr) - mbRet = bearer' <&> preview (_bearerRestrictionIx cRoute') + mbRet = bearer <&> preview (_bearerRestrictionIx cRoute') case sbool @(FoldRequired mods) of SFalse -> return $ join mbRet STrue -> maybe (throwE noTokenProvided) (maybe (throwE noRestrictionProvided) return) mbRet @@ -82,6 +73,7 @@ instance ( HasServer sub context instance ( HasServer sub context , SBoolI (FoldRequired mods) + , HasContextEntry context (Maybe (BearerToken UniWorX)) ) => HasServer (CaptureBearerToken' mods :> sub) context where @@ -91,21 +83,20 @@ instance ( HasServer sub context hoistServerWithContext _ pc nt s = hoistServerWithContext (Proxy @sub) pc nt . s route _ context subserver - = route (Proxy @sub) context (subserver `addAuthCheck` withRequest bearerCheck) + = route (Proxy @sub) context (subserver `addAuthCheck` bearerCheck) where - bearerCheck :: W.Request -> DelayedIO (RequiredArgument mods (BearerToken UniWorX)) - bearerCheck req = do - let bearer = Vault.lookup waiBearerKey $ vault req + bearerCheck :: DelayedIO (RequiredArgument mods (BearerToken UniWorX)) + bearerCheck = do + let bearer :: Maybe (BearerToken UniWorX) + bearer = getContextEntry context - noTokenStored, noTokenProvided :: ServerError - noTokenStored = err500 { errBody = "servantYesodMiddleware did not store bearer token in WAI vault." } + noTokenProvided :: ServerError noTokenProvided = err400 { errBody = "The behaviour of this route depends on restrictions stored in the bearer token used for authorization. Therefor providing a bearer token is required." } exceptT delayedFailFatal return $ do - bearer' <- maybeExceptT' noTokenStored bearer case sbool @(FoldRequired mods) of - SFalse -> return bearer' - STrue -> maybe (throwE noTokenProvided) return bearer' + SFalse -> return bearer + STrue -> maybe (throwE noTokenProvided) return bearer instance ( HasServer sub context @@ -132,23 +123,40 @@ instance ( HasServer sub context decrypt' inp = left tshow . runCatch . runReaderT (decrypt inp) . appCryptoIDKey $ getContextEntry context -type UniWorXContext = UniWorX ': '[] +type UniWorXContext = Maybe (Route UniWorX) ': Maybe (BearerToken UniWorX) ': IsDryRun ': UniWorX ': '[] type ServantHandler = ServantHandlerFor UniWorX type ServantDB = ServantDBFor UniWorX deriving via (ServantLogYesod UniWorX) instance ServantLog UniWorX +instance HasServantHandlerContext UniWorX where + data ServantHandlerContextFor UniWorX = ServantHandlerContextForUniWorX + { usctxSite :: UniWorX + , usctxRequest :: W.Request + , usctxIsDryRun :: IsDryRun + } + getSCtxSite = usctxSite + getSCtxRequest = usctxRequest + class (HasServer (ServantApiUnproxy' proxy) UniWorXContext, Servant.HasRoute (ServantApiUnproxy' proxy)) => ServantApiDispatchUniWorX proxy where servantServer' :: ServantApi proxy -> ServerT (ServantApiUnproxy' proxy) ServantHandler instance ServantApiDispatchUniWorX proxy => ServantApiDispatch UniWorXContext ServantHandler UniWorX proxy where - servantContext _ app _ = return $ app :. EmptyContext - servantHoist _ sctxSite sctxRequest _ = ($ ServantHandlerContextFor{..}) . unServantHandlerFor - servantMiddleware _ _ _ = modifyResponse (mapResponseHeaders setDefaultHeaders) . fixTrailingSlash - servantYesodMiddleware _ _ = appEndo <$> foldMapM (fmap Endo) [storeBearerToken, storeCurrentRoute] + servantContext _ app _ = do + isDryRun' <- MkIsDryRun <$> isDryRun + restr <- maybeBearerToken + cRoute <- getCurrentRoute + return $ cRoute :. restr :. isDryRun' :. app :. EmptyContext + servantHoist _ usctxSite usctxRequest ctx = ($ ServantHandlerContextForUniWorX{ usctxIsDryRun = getContextEntry ctx, .. }) . unServantHandlerFor + servantMiddleware _ _ ctx = appEndo . foldMap Endo $ + guardOn (unIsDryRun $ getContextEntry ctx) (modifyResponse $ mapResponseHeaders setDryRunHeader) + ++ [ modifyResponse (mapResponseHeaders setDefaultHeaders) + , fixTrailingSlash + ] + servantYesodMiddleware _ _ = return id servantServer proxy _ = servantServer' proxy -setDefaultHeaders :: ResponseHeaders -> ResponseHeaders +setDefaultHeaders, setDryRunHeader :: ResponseHeaders -> ResponseHeaders setDefaultHeaders existing = HashMap.toList $ HashMap.fromList existing <> defaultHeaders where defaultHeaders = HashMap.fromList [ ("X-Frame-Options", "sameorigin") @@ -156,6 +164,7 @@ setDefaultHeaders existing = HashMap.toList $ HashMap.fromList existing <> defau , ("Vary", "Accept") , ("X-XSS-Protection", "1; mode=block") ] +setDryRunHeader existing = HashMap.toList $ HashMap.fromList existing <> HashMap.singleton (CI.mk . encodeUtf8 $ toPathPiece HeaderDryRun) (encodeUtf8 $ toPathPiece True) fixTrailingSlash :: Middleware -- ^ `servant-server` contains a special case in their implementation @@ -170,17 +179,6 @@ fixTrailingSlash = (. fixTrailingSlash') | otherwise = req -storeBearerToken, storeCurrentRoute :: HandlerFor UniWorX Middleware -storeBearerToken = do - restr <- maybeBearerToken - return $ \app req -> app req{ vault = Vault.insert waiBearerKey restr $ vault req } -storeCurrentRoute = do - cRoute <- getCurrentRoute - - $logDebugS "storeCurrentRoute" $ tshow cRoute - - return $ \app req -> app req{ vault = maybe id (Vault.insert waiRouteKey) cRoute $ vault req } - instance ServantPersist UniWorX where runDB :: HasCallStack => ServantDBFor UniWorX a -> ServantHandlerFor UniWorX a @@ -189,9 +187,9 @@ instance ServantPersist UniWorX where runDB' :: CallStack -> ServantDBFor UniWorX a -> ServantHandlerFor UniWorX a runDB' lbl action = do $logDebugS "ServantPersist" "runDB" - -- let action' = do - -- dryRun <- isDryRunDB - -- if | dryRun -> action <* transactionUndo - -- | otherwise -> action + MkIsDryRun dryRun <- getsServantContext usctxIsDryRun + let action' + | dryRun = action <* transactionUndo + | otherwise = action - flip (runSqlPoolRetry' action) lbl . appConnPool =<< getSite + flip (runSqlPoolRetry' action') lbl . appConnPool =<< getSite diff --git a/src/Foundation/Servant/Types.hs b/src/Foundation/Servant/Types.hs index 3e4c8b4d1..f10462d14 100644 --- a/src/Foundation/Servant/Types.hs +++ b/src/Foundation/Servant/Types.hs @@ -5,12 +5,14 @@ module Foundation.Servant.Types , CaptureBearerToken, CaptureBearerToken' , CaptureCryptoID', CaptureCryptoID, CaptureCryptoUUID, CaptureCryptoFileName , ApiVersion, apiVersionToSemVer, matchesApiVersion + , BearerAuth, SessionAuth ) where -import ClassyPrelude +import ClassyPrelude hiding (fromList) import Data.Proxy import Servant.API +import Servant.API.Modifiers (FoldRequired) import Servant.API.Description import Servant.Swagger import Servant.Docs @@ -21,6 +23,13 @@ import Servant.Server.Internal.Delayed import Servant.Server.Internal.ErrorFormatter -- import Servant.Server.Internal.DelayedIO +import Servant.Client.Core.RunClient (RunClient) +import Servant.Client.Core.HasClient +import qualified Servant.Client.Core.Request as Servant (Request) +import qualified Servant.Client.Core.Request as Request + +import Jose.Jwt (Jwt(..)) + import Network.Wai (mapResponseHeaders, requestHeaders) import Control.Lens hiding (Context) @@ -31,8 +40,9 @@ import Data.CryptoID.Class.ImplicitNamespace import Data.CryptoID.Instances () import GHC.TypeLits +import GHC.Exts (IsList(..)) -import Data.Swagger (ToParamSchema) +import Data.Swagger hiding (version) import Data.Kind (Type) @@ -114,6 +124,11 @@ instance HasDocs sub => HasDocs (CaptureBearerToken' mods :> sub) where instance (ToCapture (Capture sym ciphertext), KnownSymbol sym, HasDocs sub) => HasDocs (CaptureCryptoID' mods ciphertext sym plaintext :> sub) where docsFor _ = docsFor $ Proxy @(Capture' mods sym ciphertext :> sub) +instance (RunClient m, HasClient m (Capture' mods sym (CryptoID ciphertext plaintext) :> sub)) => HasClient m (CaptureCryptoID' mods ciphertext sym plaintext :> sub) where + type Client m (CaptureCryptoID' mods ciphertext sym plaintext :> sub) = Client m (Capture' mods sym (CryptoID ciphertext plaintext) :> sub) + clientWithRoute pm _ = clientWithRoute pm $ Proxy @(Capture' mods sym (CryptoID ciphertext plaintext) :> sub) + hoistClientMonad pm _ = hoistClientMonad pm $ Proxy @(Capture' mods sym (CryptoID ciphertext plaintext) :> sub) + type family ApiVersionSub major minor patch sup sub where ApiVersionSub major minor patch (ApiVersion major' minor' patch') sub = ApiVersion major' minor' patch' :> sub @@ -143,8 +158,30 @@ instance ( HasServer (ApiVersion major minor patch :> a) context choice' = case (sbool :: SBool (IsLT (CmpVersion (FinalApiVersion (ApiVersion major minor patch :> a)) (FinalApiVersion (ApiVersion major minor patch :> b))))) of STrue -> flip choice SFalse -> choice + +instance (RunClient m, HasClient m (ApiVersionSub major minor patch sup sub)) => HasClient m (ApiVersion major minor patch :> ((sup :: Type) :> sub)) where + type Client m (ApiVersion major minor patch :> (sup :> sub)) = Client m (ApiVersionSub major minor patch sup sub) + clientWithRoute pm _ = clientWithRoute pm $ Proxy @(ApiVersionSub major minor patch sup sub) + hoistClientMonad pm _ = hoistClientMonad pm $ Proxy @(ApiVersionSub major minor patch sup sub) + +instance (RunClient m, HasClient m (sup :> (ApiVersion major minor patch :> sub))) => HasClient m (ApiVersion major minor patch :> ((sup :: Symbol) :> sub)) where + type Client m (ApiVersion major minor patch :> (sup :> sub)) = Client m (sup :> (ApiVersion major minor patch :> sub)) + clientWithRoute pm _ = clientWithRoute pm $ Proxy @(sup :> (ApiVersion major minor patch :> sub)) + hoistClientMonad pm _ = hoistClientMonad pm $ Proxy @(sup :> (ApiVersion major minor patch :> sub)) + +instance ( HasClient m (ApiVersion major minor patch :> a) + , HasClient m (ApiVersion major minor patch :> b) + ) => HasClient m (ApiVersion major minor patch :> (a :<|> b)) where + type Client m (ApiVersion major minor patch :> (a :<|> b)) = Client m (ApiVersion major minor patch :> a) :<|> Client m (ApiVersion major minor patch :> b) + clientWithRoute pm _ req = clientWithRoute pm (Proxy @(ApiVersion major minor patch :> a)) req + :<|> clientWithRoute pm (Proxy @(ApiVersion major minor patch :> b)) req + hoistClientMonad pm _ f (ca :<|> cb) = hoistClientMonad pm (Proxy @(ApiVersion major minor patch :> a)) f ca + :<|> hoistClientMonad pm (Proxy @(ApiVersion major minor patch :> b)) f cb +versionRequestHeaderName :: CI ByteString +versionRequestHeaderName = "Accept-API-Version" + routeWithApiVersion :: forall api context env major minor patch. ( HasServer api context , KnownNat major, KnownNat minor, KnownNat patch @@ -168,7 +205,6 @@ routeWithApiVersion _ _ context subserver = RawRouter $ \env req ((. addVersion) version = apiVersionToSemVer $ Proxy @(ApiVersion major minor patch) versionHeaderName = "API-Version" - versionRequestHeaderName = "Accept-API-Version" versionHeader = encodeUtf8 $ SemVer.toText version notFound = notFoundErrorFormatter . getContextEntry $ mkContextWithErrorFormatter context @@ -194,6 +230,26 @@ instance ( HasServer (NoContentVerb method) context route _ = routeWithApiVersion (Proxy @(ApiVersion major minor patch)) (Proxy @(NoContentVerb method)) +semVerCompatibleTo :: SemVer.Version -> SemVer.Constraint +semVerCompatibleTo v = SemVer.Constraint.CAnd (SemVer.Constraint.CGtEq v) (SemVer.Constraint.CLt $ SemVer.incrementMajor v) + +instance ( HasClient m (Verb method statusCode contentTypes a) + , KnownNat major, KnownNat minor, KnownNat patch + ) => HasClient m (ApiVersion major minor patch :> Verb method statusCode contentTypes a) where + type Client m (ApiVersion major minor patch :> Verb method statusCode contentTypes a) = Client m (Verb method statusCode contentTypes a) + clientWithRoute pm _ = clientWithRoute pm (Proxy @(Verb method statusCode contentTypes a)) . Request.addHeader versionRequestHeaderName (semVerCompatibleTo version) + where version = apiVersionToSemVer $ Proxy @(ApiVersion major minor patch) + hoistClientMonad pm _ = hoistClientMonad pm $ Proxy @(Verb method statusCode contentTypes a) + +instance ( HasClient m (NoContentVerb method) + , KnownNat major, KnownNat minor, KnownNat patch + ) => HasClient m (ApiVersion major minor patch :> NoContentVerb method) where + type Client m (ApiVersion major minor patch :> NoContentVerb method) = Client m (NoContentVerb method) + clientWithRoute pm _ = clientWithRoute pm (Proxy @(NoContentVerb method)) . Request.addHeader versionRequestHeaderName (semVerCompatibleTo version) + where version = apiVersionToSemVer $ Proxy @(ApiVersion major minor patch) + hoistClientMonad pm _ = hoistClientMonad pm $ Proxy @(NoContentVerb method) + + instance ( HasDocs (ApiVersionSub major minor patch sup sub) ) => HasDocs (ApiVersion major minor patch :> ((sup :: Type) :> sub)) where docsFor _ = docsFor $ Proxy @(ApiVersionSub major minor patch sup sub) @@ -263,3 +319,105 @@ type family IsLT x where type instance IsElem' sa (CaptureCryptoID' mods ciphertext sym plaintext :> sb) = IsElem sa (Capture' mods sym (CryptoID ciphertext plaintext) :> sb) type instance IsElem' sa (ApiVersion major minor patch :> sb) = IsElem sa sb + + +type family StripBearer api where + StripBearer (CaptureBearerRestriction' mods restr :> sub) = sub + StripBearer (CaptureBearerToken' mods :> sub) = sub + StripBearer (BearerAuth :> sub) = sub + StripBearer (sup :> sub) = sup :> StripBearer sub + StripBearer (a :<|> b) = StripBearer a :<|> StripBearer b + StripBearer (Verb method statusCode contentTypes a) = Verb method statusCode contentTypes a + StripBearer (NoContentVerb method) = NoContentVerb method + +type family BearerRequired api where + BearerRequired (CaptureBearerRestriction' mods restr :> sub) = OrBool (FoldRequired mods) (BearerRequired sub) + BearerRequired (CaptureBearerToken' mods :> sub) = OrBool (FoldRequired mods) (BearerRequired sub) + BearerRequired (BearerAuth :> sub) = 'True + BearerRequired (sup :> sub) = BearerRequired sub + BearerRequired (a :<|> b) = OrBool (BearerRequired a) (BearerRequired b) + BearerRequired (Verb method statusCode contentTypes a) = 'False + BearerRequired (NoContentVerb method) = 'False + +type family OrBool a b where + OrBool 'False 'False = 'False + OrBool a b = 'True + +maybeWithJwt :: forall (a :: Bool). SBoolI a => Proxy a -> If a Jwt (Maybe Jwt) -> Servant.Request -> Servant.Request +maybeWithJwt _ mparam = case (sbool :: SBool a, mparam) of + (STrue, jwt) -> add jwt + (SFalse, mJwt) -> maybe id add mJwt + where add (Jwt jwt) = Request.addHeader "Authorization" . decodeUtf8 $ "Bearer " <> jwt + +instance ( HasClient m (StripBearer sub) + , RunClient m + , SBoolI (BearerRequired (CaptureBearerRestriction' mods restr :> sub)) + ) => HasClient m (CaptureBearerRestriction' mods restr :> sub) where + type Client m (CaptureBearerRestriction' mods restr :> sub) = If (BearerRequired (CaptureBearerRestriction' mods restr :> sub)) Jwt (Maybe Jwt) -> Client m (StripBearer sub) + clientWithRoute pm _ req mparam = clientWithRoute pm (Proxy @(StripBearer sub)) $ maybeWithJwt (Proxy @(BearerRequired (CaptureBearerRestriction' mods restr :> sub))) mparam req + hoistClientMonad pm _ f cl = hoistClientMonad pm (Proxy @(StripBearer sub)) f . cl + +instance ( HasClient m (StripBearer sub) + , RunClient m + , SBoolI (BearerRequired (CaptureBearerToken' mods :> sub)) + ) => HasClient m (CaptureBearerToken' mods :> sub) where + type Client m (CaptureBearerToken' mods :> sub) = If (BearerRequired (CaptureBearerToken' mods :> sub)) Jwt (Maybe Jwt) -> Client m (StripBearer sub) + clientWithRoute pm _ req mparam = clientWithRoute pm (Proxy @(StripBearer sub)) $ maybeWithJwt (Proxy @(BearerRequired (CaptureBearerToken' mods :> sub))) mparam req + hoistClientMonad pm _ f cl = hoistClientMonad pm (Proxy @(StripBearer sub)) f . cl + +instance ( HasClient m (StripBearer sub) + , RunClient m + , SBoolI (BearerRequired (BearerAuth :> sub)) + ) => HasClient m (BearerAuth :> sub) where + type Client m (BearerAuth :> sub) = If (BearerRequired (BearerAuth :> sub)) Jwt (Maybe Jwt) -> Client m (StripBearer sub) + clientWithRoute pm _ req mparam = clientWithRoute pm (Proxy @(StripBearer sub)) $ maybeWithJwt (Proxy @(BearerRequired (BearerAuth :> sub))) mparam req + hoistClientMonad pm _ f cl = hoistClientMonad pm (Proxy @(StripBearer sub)) f . cl + + +data BearerAuth +data SessionAuth + +instance HasSwagger sub => HasSwagger (BearerAuth :> sub) where + toSwagger _ = toSwagger (Proxy @sub) + & securityDefinitions <>~ SecurityDefinitions (fromList [(defnKey, defn)]) + & allOperations . security <>~ [SecurityRequirement $ fromList [(defnKey, [])]] + where defnKey :: Text + defnKey = "bearer" + defn = SecurityScheme + { _securitySchemeType + = SecuritySchemeApiKey ApiKeyParams + { _apiKeyName = "Authorization" + , _apiKeyIn = ApiKeyHeader + } + , _securitySchemeDescription = Just + "JSON Web Token-based API key" + } + +instance HasSwagger sub => HasSwagger (SessionAuth :> sub) where + toSwagger _ = toSwagger (Proxy @sub) + & allOperations . security <>~ [SecurityRequirement mempty] + -- We do not expect API clients to be able/willing to conform with + -- our CSRF mitigation, so we mark routes that require it as + -- having unfullfillable security requirements + +instance HasLink sub => HasLink (BearerAuth :> sub) where + type MkLink (BearerAuth :> sub) a = MkLink sub a + toLink toA _ = toLink toA (Proxy @sub) + +instance HasLink sub => HasLink (SessionAuth :> sub) where + type MkLink (SessionAuth :> sub) a = MkLink sub a + toLink toA _ = toLink toA (Proxy @sub) + +instance HasDocs sub => HasDocs (BearerAuth :> sub) where + docsFor _ (endpoint, action) = docsFor (Proxy @sub) (endpoint, action') + where action' = action & authInfo %~ (|> authInfo') + authInfo' = DocAuthentication + "" + "A JSON Web Token-based API key" + +instance HasDocs sub => HasDocs (SessionAuth :> sub) where + docsFor _ (endpoint, action) = docsFor (Proxy @sub) (endpoint, action') + where action' = action & authInfo %~ (|> authInfo') + authInfo' = DocAuthentication + "When a web session is used for authorization, CSRF-mitigation measures must be observed." + "An active web session identifying the user as one with sufficient authorization" diff --git a/src/Import/NoModel.hs b/src/Import/NoModel.hs index fc51f0302..06a8e4af7 100644 --- a/src/Import/NoModel.hs +++ b/src/Import/NoModel.hs @@ -40,7 +40,7 @@ import Yesod.Core.Types as Import (loggerSet) import Yesod.Default.Config2 as Import import Yesod.Core.Types.Instances as Import import Yesod.Servant as Import - hiding ( MonadHandler(..), HasRoute(..) + hiding ( MonadHandler(..), HasRoute(..), MonadRequest(..) , runDB, defaultRunDB ) import Servant.Docs as Import @@ -210,6 +210,7 @@ import Data.MonoTraversable.Instances as Import () import Servant.Client.Core.BaseUrl.Instances as Import () import Control.Monad.Trans.Except.Instances as Import () import Servant.Server.Instances as Import () +import Servant.Docs.Internal.Pretty.Instances as Import () import Network.URI.Instances as Import () import Data.HashSet.Instances as Import () import Web.Cookie.Instances as Import () diff --git a/src/Import/Servant/NoFoundation.hs b/src/Import/Servant/NoFoundation.hs index 7e8c66b84..c9b4e06c6 100644 --- a/src/Import/Servant/NoFoundation.hs +++ b/src/Import/Servant/NoFoundation.hs @@ -14,6 +14,7 @@ import Import.NoFoundation as Import hiding , MonadHandler(..), HasRoute(..), liftHandler , encrypt, decrypt , Unique, Fragment(..), respond + , getRequest ) import Yesod.Servant as Import diff --git a/src/Servant/Docs/Internal/Pretty/Instances.hs b/src/Servant/Docs/Internal/Pretty/Instances.hs new file mode 100644 index 000000000..24b761d96 --- /dev/null +++ b/src/Servant/Docs/Internal/Pretty/Instances.hs @@ -0,0 +1,14 @@ +{-# OPTIONS_GHC -fno-warn-orphans #-} + +module Servant.Docs.Internal.Pretty.Instances () where + +import ClassyPrelude + +import Servant.Docs.Internal.Pretty +import Servant.API.ContentTypes + +import Data.Proxy + + +instance MimeUnrender JSON a => MimeUnrender PrettyJSON a where + mimeUnrender _ = mimeUnrender $ Proxy @JSON diff --git a/src/ServantApi/ExternalApis/Type.hs b/src/ServantApi/ExternalApis/Type.hs index 181f2bca2..4a1b6be51 100644 --- a/src/ServantApi/ExternalApis/Type.hs +++ b/src/ServantApi/ExternalApis/Type.hs @@ -15,17 +15,22 @@ import Jose.Jwk (JwkSet(..)) {-# ANN module ("HLint: ignore Use newtype instead of data" :: String) #-} -type ExternalApisListR = Get '[PrettyJSON] ExternalApisList -type ExternalApisCreateR = CaptureBearerRestriction' '[Optional] ExternalApiCreationRestrictions +type ExternalApisListR = ApiVersion 1 0 0 + :> Get '[PrettyJSON] ExternalApisList +type ExternalApisCreateR = ApiVersion 1 0 0 + :> CaptureBearerRestriction' '[Optional] ExternalApiCreationRestrictions :> CaptureBearerToken :> ReqBody '[JSON] ExternalApiCreationRequest :> PostCreated '[PrettyJSON] (Headers '[Header "Location" URI] ExternalApiCreationResponse) -type ExternalApisPongR = CaptureCryptoUUID "external-api" ExternalApiId +type ExternalApisPongR = ApiVersion 1 0 0 + :> CaptureCryptoUUID "external-api" ExternalApiId :> "pong" :> Post '[PrettyJSON] ExternalApiPongResponse -type ExternalApisInfoR = CaptureCryptoUUID "external-api" ExternalApiId +type ExternalApisInfoR = ApiVersion 1 0 0 + :> CaptureCryptoUUID "external-api" ExternalApiId :> Get '[PrettyJSON] ExternalApiInfo -type ExternalApisDeleteR = CaptureCryptoUUID "external-api" ExternalApiId +type ExternalApisDeleteR = ApiVersion 1 0 0 + :> CaptureCryptoUUID "external-api" ExternalApiId :> DeleteNoContent data ExternalApis mode = ExternalApis @@ -37,7 +42,7 @@ data ExternalApis mode = ExternalApis } deriving (Generic) type ServantApiExternalApis = ServantApi ExternalApis -type instance ServantApiUnproxy ExternalApis = ApiVersion 1 0 0 :> ToServantApi ExternalApis +type instance ServantApiUnproxy ExternalApis = ToServantApi ExternalApis instance ToCapture (Capture "external-api" UUID) where @@ -122,7 +127,9 @@ data ExternalApiInfo = ExternalApiInfo instance ToJSON ExternalApiInfo where toJSON ExternalApiInfo{..} = object $ maybe id ((:) . ("ident" .=)) eaiIdent - [ "token-authority" .= foldMap (HashSet.singleton . either id toJSON) eaiTokenAuthority + [ "token-authority" .= case HashSet.toList eaiTokenAuthority of + [x] -> either id toJSON x + _ -> toJSON $ foldMap (HashSet.singleton . either id toJSON) eaiTokenAuthority , "token-issued" .= eaiTokenIssued , "token-expires-at" .= eaiTokenExpiresAt , "token-starts-at" .= eaiTokenStartsAt @@ -134,11 +141,11 @@ instance ToJSON ExternalApiInfo where instance FromJSON ExternalApiInfo where parseJSON = withObject "ExternalApiInfo" $ \o -> do - eaiIdent <- o .:? "token-authority" + eaiIdent <- o .:? "ident" eaiTokenAuthority <- asum - [ HashSet.singleton . Right <$> o .: "authority" - , (o .: "authority" :: _ (HashSet Value)) >>= foldMapM (\v' -> fmap HashSet.singleton $ (Right <$> parseJSON v') <|> return (Left v')) - , HashSet.singleton . Left <$> o .: "authority" + [ HashSet.singleton . Right <$> o .: "token-authority" + , (o .: "token-authority" :: _ (HashSet Value)) >>= foldMapM (\v' -> fmap HashSet.singleton $ (Right <$> parseJSON v') <|> return (Left v')) + , HashSet.singleton . Left <$> o .: "token-authority" ] eaiTokenIssued <- o .: "token-issued" eaiTokenExpiresAt <- o .: "token-expires-at" diff --git a/src/Utils.hs b/src/Utils.hs index 862ae9cc9..4b8f8de89 100644 --- a/src/Utils.hs +++ b/src/Utils.hs @@ -113,7 +113,9 @@ import Data.Binary (Binary) import qualified Data.Binary as Binary import Network.Wai (requestMethod) -import Network.HTTP.Types.Header +import Network.HTTP.Types.Header as Wai + +import Web.HttpApiData import Data.Time.Clock @@ -1143,6 +1145,9 @@ addCustomHeader, replaceOrAddCustomHeader :: (MonadHandler m, PathPiece payload) addCustomHeader ident payload = addHeader (toPathPiece ident) (toPathPiece payload) replaceOrAddCustomHeader ident payload = replaceOrAddHeader (toPathPiece ident) (toPathPiece payload) +waiCustomHeader :: ToHttpApiData payload => CustomHeader -> payload -> Wai.Header +waiCustomHeader ident payload = (CI.mk . encodeUtf8 $ toPathPiece ident, toHeader payload) + ------------------ -- HTTP Headers -- ------------------ diff --git a/src/Yesod/Servant.hs b/src/Yesod/Servant.hs index 2f2e567d4..3abe2732d 100644 --- a/src/Yesod/Servant.hs +++ b/src/Yesod/Servant.hs @@ -8,8 +8,8 @@ module Yesod.Servant , ServantApiDispatch(..) , servantApiLink , ServantHandlerFor(..) - , ServantHandlerContextFor(..), getServantContext, getsServantContext, getYesodApproot, renderRouteAbsolute - , MonadServantHandler(..), MonadHandler(..), MonadSite(..) + , HasServantHandlerContext(..), getServantContext, getsServantContext, getYesodApproot, renderRouteAbsolute, servantApiBaseUrl + , MonadServantHandler(..), MonadHandler(..), MonadSite(..), MonadRequest(..) , ServantDBFor, ServantPersist(..), defaultRunDB , ServantLog(..), ServantLogYesod(..) , mkYesodApi @@ -45,6 +45,8 @@ import Servant.API import Servant.Server hiding (route) import Servant.Server.Instances () +import Servant.Client.Core.BaseUrl + import Data.Proxy import Network.Wai (Request, Middleware) @@ -59,13 +61,10 @@ import Control.Monad.Fail (MonadFail(..)) import Data.Data (Data) import Data.Kind (Type) -import GHC.Exts (IsList(..), Constraint) +import GHC.Exts (Constraint) -import Servant.Swagger import Data.Swagger -import Servant.Docs - import qualified Data.Set as Set import Network.HTTP.Types.Status @@ -138,75 +137,92 @@ instance (Typeable m, Typeable k, Typeable status, Typeable fr, Typeable ct, Typ instance HasRoute sub => HasRoute (HttpVersion :> sub) where parseServantRoute args = parseServantRoute @sub @(ServantApiDirect sub) args <&> \case ServantApiRoute (_ :: Proxy endpoint) f ps qs -> ServantApiRoute (Proxy @(HttpVersion :> endpoint)) f ps qs + ServantApiBaseRoute -> ServantApiBaseRoute instance HasRoute sub => HasRoute (Vault :> sub) where parseServantRoute args = parseServantRoute @sub @(ServantApiDirect sub) args <&> \case ServantApiRoute (_ :: Proxy endpoint) f ps qs -> ServantApiRoute (Proxy @(Vault :> endpoint)) f ps qs + ServantApiBaseRoute -> ServantApiBaseRoute instance (HasRoute sub, KnownSymbol realm, Typeable a) => HasRoute (BasicAuth realm a :> sub) where parseServantRoute args = parseServantRoute @sub @(ServantApiDirect sub) args <&> \case ServantApiRoute (_ :: Proxy endpoint) f ps qs -> ServantApiRoute (Proxy @(BasicAuth realm a :> endpoint)) f ps qs + ServantApiBaseRoute -> ServantApiBaseRoute instance (HasRoute sub, KnownSymbol s) => HasRoute (Description s :> sub) where parseServantRoute args = parseServantRoute @sub @(ServantApiDirect sub) args <&> \case ServantApiRoute (_ :: Proxy endpoint) f ps qs -> ServantApiRoute (Proxy @(Description s :> endpoint)) f ps qs + ServantApiBaseRoute -> ServantApiBaseRoute instance (HasRoute sub, KnownSymbol s) => HasRoute (Summary s :> sub) where parseServantRoute args = parseServantRoute @sub @(ServantApiDirect sub) args <&> \case ServantApiRoute (_ :: Proxy endpoint) f ps qs -> ServantApiRoute (Proxy @(Summary s :> endpoint)) f ps qs + ServantApiBaseRoute -> ServantApiBaseRoute instance (HasRoute sub, Typeable tag, Typeable k) => HasRoute (AuthProtect (tag :: k) :> sub) where parseServantRoute args = parseServantRoute @sub @(ServantApiDirect sub) args <&> \case ServantApiRoute (_ :: Proxy endpoint) f ps qs -> ServantApiRoute (Proxy @(AuthProtect tag :> endpoint)) f ps qs + ServantApiBaseRoute -> ServantApiBaseRoute instance HasRoute sub => HasRoute (IsSecure :> sub) where parseServantRoute args = parseServantRoute @sub @(ServantApiDirect sub) args <&> \case ServantApiRoute (_ :: Proxy endpoint) f ps qs -> ServantApiRoute (Proxy @(IsSecure :> endpoint)) f ps qs + ServantApiBaseRoute -> ServantApiBaseRoute instance HasRoute sub => HasRoute (RemoteHost :> sub) where parseServantRoute args = parseServantRoute @sub @(ServantApiDirect sub) args <&> \case ServantApiRoute (_ :: Proxy endpoint) f ps qs -> ServantApiRoute (Proxy @(RemoteHost :> endpoint)) f ps qs + ServantApiBaseRoute -> ServantApiBaseRoute instance (HasRoute sub, Typeable mods, Typeable restr) => HasRoute (CaptureBearerRestriction' mods restr :> sub) where parseServantRoute args = parseServantRoute @sub @(ServantApiDirect sub) args <&> \case ServantApiRoute (_ :: Proxy endpoint) f ps qs -> ServantApiRoute (Proxy @(CaptureBearerRestriction' mods restr :> endpoint)) f ps qs + ServantApiBaseRoute -> ServantApiBaseRoute instance (HasRoute sub, Typeable mods) => HasRoute (CaptureBearerToken' mods :> sub) where parseServantRoute args = parseServantRoute @sub @(ServantApiDirect sub) args <&> \case ServantApiRoute (_ :: Proxy endpoint) f ps qs -> ServantApiRoute (Proxy @(CaptureBearerToken' mods :> endpoint)) f ps qs + ServantApiBaseRoute -> ServantApiBaseRoute instance (KnownSymbol sym, HasRoute sub, HasLink sub) => HasRoute (sym :> sub) where parseServantRoute (p : ps, qs) | p == escapedSymbol (Proxy @sym) = parseServantRoute @sub @(ServantApiDirect sub) (ps, qs) <&> \case ServantApiRoute (_ :: Proxy endpoint) f ps' qs' -> ServantApiRoute (Proxy @(sym :> endpoint)) f (escapedSymbol (Proxy @sym) : ps') qs' + ServantApiBaseRoute -> ServantApiBaseRoute parseServantRoute _ = Nothing instance (HasRoute a, HasRoute b) => HasRoute (a :<|> b) where parseServantRoute args = asum [ parseServantRoute @a @(ServantApiDirect a) args <&> \case ServantApiRoute (_ :: Proxy endpoint) f ps qs -> ServantApiRoute (Proxy @endpoint) f ps qs + ServantApiBaseRoute -> ServantApiBaseRoute , parseServantRoute @b @(ServantApiDirect b) args <&> \case ServantApiRoute (_ :: Proxy endpoint) f ps qs -> ServantApiRoute (Proxy @endpoint) f ps qs + ServantApiBaseRoute -> ServantApiBaseRoute ] instance (HasRoute sub, Typeable mods, Typeable ct, Typeable a) => HasRoute (ReqBody' mods ct a :> sub) where parseServantRoute args = parseServantRoute @sub @(ServantApiDirect sub) args <&> \case ServantApiRoute (_ :: Proxy endpoint) f ps qs -> ServantApiRoute (Proxy @(ReqBody' mods ct a :> endpoint)) f ps qs + ServantApiBaseRoute -> ServantApiBaseRoute instance (HasRoute sub, Typeable mods, Typeable framing, Typeable ct, Typeable a) => HasRoute (StreamBody' mods framing ct a :> sub) where parseServantRoute args = parseServantRoute @sub @(ServantApiDirect sub) args <&> \case ServantApiRoute (_ :: Proxy endpoint) f ps qs -> ServantApiRoute (Proxy @(StreamBody' mods framing ct a :> endpoint)) f ps qs + ServantApiBaseRoute -> ServantApiBaseRoute instance (HasRoute sub, KnownSymbol sym, Typeable mods, Typeable a) => HasRoute (Header' mods sym (a :: Type) :> sub) where parseServantRoute args = parseServantRoute @sub @(ServantApiDirect sub) args <&> \case ServantApiRoute (_ :: Proxy endpoint) f ps qs -> ServantApiRoute (Proxy @(Header' mods sym a :> endpoint)) f ps qs + ServantApiBaseRoute -> ServantApiBaseRoute instance (HasRoute sub, Typeable mods, KnownSymbol sym, Typeable v, ToHttpApiDataInjective v, FromHttpApiData v) => HasRoute (Capture' mods sym (v :: Type) :> sub) where parseServantRoute (p : ps, qs) | Right v <- parseUrlPiece @v p = parseServantRoute @sub @(ServantApiDirect sub) (ps, qs) <&> \case ServantApiRoute (_ :: Proxy endpoint) f ps' qs' -> ServantApiRoute (Proxy @(Capture' mods sym v :> endpoint)) (f . ($ v)) (toUrlPieceInjective v : ps') qs' + ServantApiBaseRoute -> ServantApiBaseRoute parseServantRoute _ = Nothing instance (HasRoute sub, Typeable mods, KnownSymbol sym, Typeable plaintext, ToHttpApiDataInjective ciphertext, FromHttpApiData ciphertext, Typeable ciphertext) => HasRoute (CaptureCryptoID' mods ciphertext sym plaintext :> sub) where @@ -214,11 +230,13 @@ instance (HasRoute sub, Typeable mods, KnownSymbol sym, Typeable plaintext, ToHt | Right v <- parseUrlPiece @(CryptoID ciphertext plaintext) p = parseServantRoute @sub @(ServantApiDirect sub) (ps, qs) <&> \case ServantApiRoute (_ :: Proxy endpoint) f ps' qs' -> ServantApiRoute (Proxy @(CaptureCryptoID' mods ciphertext sym plaintext :> endpoint)) (f . ($ v)) (toUrlPieceInjective v : ps') qs' + ServantApiBaseRoute -> ServantApiBaseRoute parseServantRoute _ = Nothing instance (HasRoute sub, KnownNat major, KnownNat minor, KnownNat patch) => HasRoute (ApiVersion major minor patch :> sub) where parseServantRoute args = parseServantRoute @sub @(ServantApiDirect sub) args <&> \case ServantApiRoute (_ :: Proxy endpoint) f ps qs -> ServantApiRoute (Proxy @(ApiVersion major minor patch :> endpoint)) f ps qs + ServantApiBaseRoute -> ServantApiBaseRoute data ServantApi (proxy :: k) = ServantApi @@ -249,45 +267,56 @@ instance HasRoute (ServantApiUnproxy' proxy) => RenderRoute (ServantApi proxy) w (Proxy endpoint) (forall a. MkLink endpoint a -> a) [Text] (HashMap Text [Text]) + | ServantApiBaseRoute renderRoute (ServantApiRoute (_ :: Proxy endpoint) f _ _) = f $ safeLink' renderServantRoute (Proxy @(ServantApiUnproxy' proxy)) (Proxy @endpoint) + renderRoute ServantApiBaseRoute = mempty instance HasRoute (ServantApiUnproxy' proxy) => Eq (Route (ServantApi proxy)) where (ServantApiRoute (_ :: Proxy endpoint) _ ps qs) == (ServantApiRoute (_ :: Proxy endpoint') _ ps' qs') = case eqT @endpoint @endpoint' of Just Refl -> ps == ps' && qs == qs' Nothing -> False + ServantApiBaseRoute == ServantApiBaseRoute = True + _ == _ = False instance HasRoute (ServantApiUnproxy' proxy) => Ord (Route (ServantApi proxy)) where compare (ServantApiRoute (_ :: Proxy endpoint) _ ps qs) (ServantApiRoute (_ :: Proxy endpoint') _ ps' qs') = case eqT @endpoint @endpoint' of Just Refl -> compare ps ps' <> compare qs qs' Nothing -> typeRep (Proxy @endpoint) `compare` typeRep (Proxy @endpoint') + compare ServantApiBaseRoute ServantApiBaseRoute = EQ + compare ServantApiBaseRoute _ = LT + compare _ ServantApiBaseRoute = GT instance HasRoute (ServantApiUnproxy' proxy) => Hashable (Route (ServantApi proxy)) where - hashWithSalt salt (ServantApiRoute (_ :: Proxy endpoint) _ ps qs) = salt `hashWithSalt` typeRep (Proxy @endpoint) `hashWithSalt` ps `hashWithSalt` qs + hashWithSalt salt (ServantApiRoute (_ :: Proxy endpoint) _ ps qs) = salt `hashWithSalt` (0 :: Int) `hashWithSalt` typeRep (Proxy @endpoint) `hashWithSalt` ps `hashWithSalt` qs + hashWithSalt salt ServantApiBaseRoute = salt `hashWithSalt` (1 :: Int) instance HasRoute (ServantApiUnproxy' proxy) => Read (Route (ServantApi proxy)) where - readPrec = readP_to_Prec $ \d -> do - when (d > 10) . void $ R.char '(' - R.skipSpaces - void $ R.string "ServantApiRoute " - R.skipSpaces - void $ R.string "_ " - R.skipSpaces - asum [ do - void $ R.char '(' - R.skipMany . R.manyTill (R.satisfy $ const True) $ R.char ')' - void $ R.char ' ' - , R.skipMany . R.manyTill (R.satisfy $ not . Char.isSpace) $ R.satisfy Char.isSpace - ] - R.skipSpaces - ps <- readPrec_to_P readPrec 11 - void $ R.char ' ' - R.skipSpaces - qs <- readPrec_to_P readPrec 11 :: R.ReadP (HashMap Text [Text]) - R.skipSpaces - when (d > 10) . void $ R.char ')' - maybe (fail "Could not parse servant route") return $ parseServantRoute (ps, ifoldMap (fmap . (,)) qs) + readPrec = readP_to_Prec $ \d -> asum + [ ServantApiBaseRoute <$ R.string "ServantApiBaseRoute" + , do + when (d > 10) . void $ R.char '(' + R.skipSpaces + void $ R.string "ServantApiRoute " + R.skipSpaces + void $ R.string "_ " + R.skipSpaces + asum [ do + void $ R.char '(' + R.skipMany . R.manyTill (R.satisfy $ const True) $ R.char ')' + void $ R.char ' ' + , R.skipMany . R.manyTill (R.satisfy $ not . Char.isSpace) $ R.satisfy Char.isSpace + ] + R.skipSpaces + ps <- readPrec_to_P readPrec 11 + void $ R.char ' ' + R.skipSpaces + qs <- readPrec_to_P readPrec 11 :: R.ReadP (HashMap Text [Text]) + R.skipSpaces + when (d > 10) . void $ R.char ')' + maybe (fail "Could not parse servant route") return $ parseServantRoute (ps, ifoldMap (fmap . (,)) qs) + ] instance HasRoute (ServantApiUnproxy' proxy) => Show (Route (ServantApi proxy)) where showsPrec d (ServantApiRoute (_ :: Proxy endpoint) _ ps qs) = showParen (d > 10) $ showString "ServantApiRoute " @@ -296,6 +325,7 @@ instance HasRoute (ServantApiUnproxy' proxy) => Show (Route (ServantApi proxy)) . showsPrec 11 ps . showString " " . showsPrec 11 qs + showsPrec _ ServantApiBaseRoute = showString "ServantApiBaseRoute" instance HasRoute (ServantApiUnproxy' proxy) => ParseRoute (ServantApi proxy) where parseRoute = parseServantRoute @@ -385,10 +415,10 @@ servantApiLink _ _ = safeLink' (fromMaybe (error "Could not parse result of safe guardEndpoint _ = Nothing -data ServantHandlerContextFor site = ServantHandlerContextFor - { sctxSite :: site - , sctxRequest :: Request - } +class HasServantHandlerContext site where + data ServantHandlerContextFor site :: Type + getSCtxSite :: ServantHandlerContextFor site -> site + getSCtxRequest :: ServantHandlerContextFor site -> Request newtype ServantHandlerFor site a = ServantHandlerFor { unServantHandlerFor :: ServantHandlerContextFor site -> Handler a } deriving (Generic, Typeable) @@ -404,10 +434,10 @@ getServantContext = liftServantHandler $ ServantHandlerFor return getsServantContext :: (site ~ site', MonadServantHandler site m) => (ServantHandlerContextFor site' -> a) -> m a getsServantContext = liftServantHandler . ServantHandlerFor . (return .) -getYesodApproot :: (Yesod site, MonadServantHandler site m) => m Text -getYesodApproot = getsServantContext $ \ServantHandlerContextFor{..} -> Yesod.getApprootText Yesod.approot sctxSite sctxRequest +getYesodApproot :: (Yesod site, MonadSite site m, MonadRequest m) => m Text +getYesodApproot = Yesod.getApprootText Yesod.approot <$> getSite <*> getRequest -renderRouteAbsolute :: (Yesod site, MonadServantHandler site m) => Route site -> m URI +renderRouteAbsolute :: (Yesod site, MonadSite site m, MonadRequest m) => Route site -> m URI renderRouteAbsolute (renderRoute -> (ps, qs)) = addRoute . unpack <$> getYesodApproot where addRoute root = case parseURI root of Just root' -> root' & uriPathLens . packed %~ addPath @@ -419,13 +449,16 @@ renderRouteAbsolute (renderRoute -> (ps, qs)) = addRoute . unpack <$> getYesodAp addQuery "?" = addQuery "" addQuery q = q <> "&" <> tailEx (addQuery "") -class MonadIO m => MonadServantHandler site m | m -> site where +servantApiBaseUrl :: (Yesod site, MonadSite site m, MonadRequest m, MonadThrow m) => (Route (ServantApi proxy) -> Route site) -> m BaseUrl +servantApiBaseUrl = parseBaseUrl . ($ mempty). uriToString (const "") <=< renderRouteAbsolute . ($ ServantApiBaseRoute) + +class (MonadIO m, HasServantHandlerContext site) => MonadServantHandler site m | m -> site where liftServantHandler :: forall a. ServantHandlerFor site a -> m a -instance MonadServantHandler site (ServantHandlerFor site) where +instance HasServantHandlerContext site => MonadServantHandler site (ServantHandlerFor site) where liftServantHandler = id -instance (MonadTrans t, MonadIO (t (ServantHandlerFor site))) => MonadServantHandler site (t (ServantHandlerFor site)) where +instance (MonadTrans t, MonadIO (t (ServantHandlerFor site)), HasServantHandlerContext site) => MonadServantHandler site (t (ServantHandlerFor site)) where liftServantHandler = lift class MonadIO m => MonadHandler m where @@ -443,8 +476,8 @@ class Monad m => MonadSite site m | m -> site where getsSite :: (site -> a) -> m a getsSite f = f <$> getSite -instance MonadSite site (ServantHandlerFor site) where - getSite = liftServantHandler . ServantHandlerFor $ return . sctxSite +instance HasServantHandlerContext site => MonadSite site (ServantHandlerFor site) where + getSite = liftServantHandler . ServantHandlerFor $ return . getSCtxSite instance MonadSite site (Reader site) where getSite = ask @@ -454,10 +487,22 @@ instance {-# OVERLAPPABLE #-} (Yesod.MonadHandler m, site ~ Yesod.HandlerSite m) getSite = Yesod.getYesod getsSite = Yesod.getsYesod -instance {-# OVERLAPPING #-} (MonadTrans t, Monad (t (ServantHandlerFor site))) => MonadSite site (t (ServantHandlerFor site)) where +instance {-# OVERLAPPING #-} (MonadTrans t, Monad (t (ServantHandlerFor site)), HasServantHandlerContext site) => MonadSite site (t (ServantHandlerFor site)) where getSite = lift getSite getsSite = lift . getsSite +class Monad m => MonadRequest m where + getRequest :: m Request + +instance HasServantHandlerContext site => MonadRequest (ServantHandlerFor site) where + getRequest = liftServantHandler . ServantHandlerFor $ return . getSCtxRequest + +instance {-# OVERLAPPABLE #-} (Yesod.MonadHandler m, Monad m) => MonadRequest m where + getRequest = Yesod.waiRequest + +instance {-# OVERLAPPING #-} (MonadTrans t, Monad (t (ServantHandlerFor site)), HasServantHandlerContext site) => MonadRequest (t (ServantHandlerFor site)) where + getRequest = lift getRequest + type ServantDBFor site = ReaderT (Yesod.YesodPersistBackend site) (ServantHandlerFor site) @@ -466,6 +511,7 @@ class Yesod.YesodPersist site => ServantPersist site where defaultRunDB :: ( PersistConfig c , ServantDBFor site a ~ PersistConfigBackend c (ServantHandlerFor site) a + , HasServantHandlerContext site ) => Getting c site c -> Getting (PersistConfigPool c) site (PersistConfigPool c) @@ -485,12 +531,12 @@ instance Yesod site => ServantLog (ServantLogYesod site) where logger <- Yesod.makeLogger app Yesod.messageLoggerSource app logger a b c d -instance ServantLog site => MonadLogger (ServantHandlerFor site) where +instance (ServantLog site, HasServantHandlerContext site) => MonadLogger (ServantHandlerFor site) where monadLoggerLog a b c d = do app <- getSite servantLogLog app a b c d -instance ServantLog site => MonadLoggerIO (ServantHandlerFor site) where +instance (ServantLog site, HasServantHandlerContext site) => MonadLoggerIO (ServantHandlerFor site) where askLoggerIO = servantLogLog <$> getSite @@ -501,56 +547,6 @@ instance PathPiece a => FromHttpApiData (PathPieceHttpApiData a) where parseUrlPiece = maybe (Left "Could not convert from HttpApiData via PathPiece") Right . fromPathPiece instance PathPiece a => ToHttpApiData (PathPieceHttpApiData a) where toUrlPiece = toPathPiece - - -data BearerAuth -data SessionAuth - -instance HasSwagger sub => HasSwagger (BearerAuth :> sub) where - toSwagger _ = toSwagger (Proxy @sub) - & securityDefinitions <>~ SecurityDefinitions (fromList [(defnKey, defn)]) - & allOperations . security <>~ [SecurityRequirement $ fromList [(defnKey, [])]] - where defnKey :: Text - defnKey = "bearer" - defn = SecurityScheme - { _securitySchemeType - = SecuritySchemeApiKey ApiKeyParams - { _apiKeyName = "Authorization" - , _apiKeyIn = ApiKeyHeader - } - , _securitySchemeDescription = Just - "JSON Web Token-based API key" - } - -instance HasSwagger sub => HasSwagger (SessionAuth :> sub) where - toSwagger _ = toSwagger (Proxy @sub) - & allOperations . security <>~ [SecurityRequirement mempty] - -- We do not expect API clients to be able/willing to conform with - -- our CSRF mitigation, so we mark routes that require it as - -- having unfullfillable security requirements - -instance HasLink sub => HasLink (BearerAuth :> sub) where - type MkLink (BearerAuth :> sub) a = MkLink sub a - toLink toA _ = toLink toA (Proxy @sub) - -instance HasLink sub => HasLink (SessionAuth :> sub) where - type MkLink (SessionAuth :> sub) a = MkLink sub a - toLink toA _ = toLink toA (Proxy @sub) - -instance HasDocs sub => HasDocs (BearerAuth :> sub) where - docsFor _ (endpoint, action) = docsFor (Proxy @sub) (endpoint, action') - where action' = action & authInfo %~ (|> authInfo') - authInfo' = DocAuthentication - "" - "A JSON Web Token-based API key" - -instance HasDocs sub => HasDocs (SessionAuth :> sub) where - docsFor _ (endpoint, action) = docsFor (Proxy @sub) (endpoint, action') - where action' = action & authInfo %~ (|> authInfo') - authInfo' = DocAuthentication - "When a web session is used for authorization, CSRF-mitigation measures must be observed." - "An active web session identifying the user as one with sufficient authorization" - mkYesodApi :: Name -> [ResourceTree String] -> DecsQ diff --git a/test/Model/TypesSpec.hs b/test/Model/TypesSpec.hs index 8d2758458..04344823f 100644 --- a/test/Model/TypesSpec.hs +++ b/test/Model/TypesSpec.hs @@ -51,6 +51,8 @@ import qualified Data.SemVer as SemVer import qualified Data.SemVer.Constraint as SemVer (Constraint) import qualified Data.SemVer.Constraint as SemVer.Constraint +import qualified Data.HashSet as HashSet + instance Arbitrary Season where @@ -343,7 +345,9 @@ instance Arbitrary RoomReference' where arbitrary = genericArbitrary instance Arbitrary ExternalApiConfig where - arbitrary = genericArbitrary + arbitrary = oneof + [ EApiGradelistFormat <$> ((fmap HashSet.fromList . scale (`div` 10) $ listOf1 (resize 3 arbitrary)) `suchThatMap` fromNullable) + ] shrink = genericShrink instance Arbitrary SemVer.Version where diff --git a/test/Servant/Client/Core/BaseUrl/TestInstances.hs b/test/Servant/Client/Core/BaseUrl/TestInstances.hs index 07b8a4eb3..86dbe9453 100644 --- a/test/Servant/Client/Core/BaseUrl/TestInstances.hs +++ b/test/Servant/Client/Core/BaseUrl/TestInstances.hs @@ -7,7 +7,14 @@ import Network.URI import Network.URI.Arbitrary () import Servant.Client.Core.BaseUrl +import Control.Lens.Extras + instance Arbitrary BaseUrl where - arbitrary = toBaseUrl <$> arbitrary - where toBaseUrl = either (error . displayException) id . parseBaseUrl . ($ mempty) . uriToString id + arbitrary = (`suchThatMap` toBaseUrl) $ do + uri <- scale (min 10) arbitrary `suchThat` (is _Just . uriAuthority) + uriScheme <- oneof $ map (return . (<> ":")) [ "http", "https" ] + let uriAuthority'' = uriAuthority uri <&> \uriAuthority' -> uriAuthority'{ uriUserInfo = "" } + return (uri, uriScheme, uriAuthority'') + where + toBaseUrl (uri, uriScheme, uriAuthority'') = either (const Nothing) Just . parseBaseUrl . ($ mempty) $ uriToString (const mempty) uri{ uriScheme, uriAuthority = uriAuthority'', uriQuery = "", uriFragment = "" } diff --git a/test/ServantApi/ExternalApis/TypeSpec.hs b/test/ServantApi/ExternalApis/TypeSpec.hs index 466ba5b7e..312aa6cad 100644 --- a/test/ServantApi/ExternalApis/TypeSpec.hs +++ b/test/ServantApi/ExternalApis/TypeSpec.hs @@ -8,7 +8,10 @@ import ServantApi.ExternalApis.Type instance Arbitrary ExternalApiCreationRequest where - arbitrary = genericArbitrary + arbitrary = ExternalApiCreationRequest + <$> scale (`div` 2) arbitrary + <*> scale (`div` 2) arbitrary + <*> scale (`div` 2) arbitrary shrink = genericShrink diff --git a/test/ServantApi/ExternalApisSpec.hs b/test/ServantApi/ExternalApisSpec.hs new file mode 100644 index 000000000..2fba5b343 --- /dev/null +++ b/test/ServantApi/ExternalApisSpec.hs @@ -0,0 +1,48 @@ +{-# OPTIONS_GHC -Wno-error=unused-local-binds #-} + +module ServantApi.ExternalApisSpec where + +import TestImport +import ServantApi.ExternalApis.Type +import ServantApi.ExternalApis.TypeSpec () + +import Servant.Client.Core (RequestF(..)) +import Servant.Client.Generic + +import Utils.Tokens +import Data.Time.Clock (nominalDay) + +import qualified Data.HashSet as HashSet +import qualified Data.HashMap.Strict as HashMap + +import qualified Data.Sequence as Seq + +import Control.Monad.Reader.Class (MonadReader(local)) +import Utils (CustomHeader(..), waiCustomHeader) + + +spec :: Spec +spec = withApp . describe "ExternalApis" $ do + it "Supports dryRun" $ do + adminId <- runDB $ do + Entity adminId _ <- insertEntity $ fakeUser id + ifi <- insert $ School "Institut für Informatik" "IfI" (Just $ 14 * nominalDay) (Just $ 10 * nominalDay) True (ExamModeDNF predDNFFalse) (ExamCloseOnFinished True) SchoolAuthorshipStatementModeOptional Nothing True SchoolAuthorshipStatementModeRequired Nothing False + insert_ $ UserFunction adminId ifi SchoolAdmin + return adminId + + accessToken <- runHandler $ encodeBearer =<< bearerToken (HashSet.singleton $ Right adminId) Nothing HashMap.empty Nothing Nothing Nothing + + let + insertExternalApi = void $ externalApisCreateR accessToken =<< liftIO (generate $ resize 10 arbitrary) + where ExternalApis{..} = genericClient + withDryRun :: ServantExampleEnv -> ServantExampleEnv + withDryRun seEnv = seEnv + { yseMakeClientRequest = \burl req -> yseMakeClientRequest seEnv burl req{ requestHeaders = requestHeaders req Seq.:|> waiCustomHeader HeaderDryRun True } + } + externalApiCount = runDB $ count @_ @_ @ExternalApi [] + + runServantExample ExternalApisR insertExternalApi + liftIO . (`shouldBe` 1) =<< externalApiCount + + runServantExample ExternalApisR $ local withDryRun insertExternalApi + liftIO . (`shouldBe` 1) =<< externalApiCount diff --git a/test/ServantApiSpec.hs b/test/ServantApiSpec.hs new file mode 100644 index 000000000..001e9a7e7 --- /dev/null +++ b/test/ServantApiSpec.hs @@ -0,0 +1,36 @@ +{-# LANGUAGE UndecidableInstances #-} +{-# OPTIONS_GHC -fno-warn-redundant-constraints #-} + +module ServantApiSpec where + +import TestImport +import ServantApi + +import Servant.API +import Servant.API.TypeLevel (MapSub, AppendList) +import Foundation.Servant.Types (ApiVersion) + +import GHC.TypeLits +import Data.Kind (Constraint) + + +type family Unversioned api where + Unversioned (ApiVersion _ _ _ :> _) = '[] + Unversioned (sup :> sub) = MapSub sup (Unversioned sub) + Unversioned (a :<|> b) = AppendList (Unversioned a) (Unversioned b) + Unversioned (Verb method statusCode contentTypes a) = '[Verb method statusCode contentTypes a] + Unversioned (NoContentVerb method) = '[NoContentVerb method] + +type family UnversionedError xs :: ErrorMessage where + UnversionedError (x ': '[]) = 'Text "Unversioned API endpoint: " ':$$: ('Text " " ':<>: 'ShowType x) + UnversionedError (x ': xs) = UnversionedError (x ': '[]) ':$$: UnversionedError xs + +type family IsEmpty xs :: Constraint where + IsEmpty '[] = () + IsEmpty xs = TypeError ('Text "All API endpoints must be versioned." ':$$: UnversionedError xs) + +spec :: Spec +spec = describe "Servant endpoints" $ it "are all versioned" versioned + where + versioned :: IsEmpty (Unversioned UniWorXApi) => Bool + versioned = True diff --git a/test/TestImport.hs b/test/TestImport.hs index be362d41d..ed01b32da 100644 --- a/test/TestImport.hs +++ b/test/TestImport.hs @@ -1,3 +1,5 @@ +{-# OPTIONS_GHC -fno-warn-deprecations #-} + module TestImport ( module TestImport , module X @@ -44,6 +46,34 @@ import Jobs (handleJobs) import Numeric.Natural as X import Network.URI.Arbitrary as X () +import qualified Network.Wai as Wai +import qualified Network.Wai.Test as Wai +import qualified Network.Wai.Test.Internal as Wai (ClientState) +import Network.HTTP.Types (Status(..), hContentType, hAccept) +import Network.HTTP.Types.Header (hHost) +import qualified Network.HTTP.Types as Wai + +import Control.Monad.Trans.Except (ExceptT) +import qualified Servant.Client.Core as Servant +import Servant.Client.Core.ClientError +import Servant.Client.Core.RunClient +import Control.Monad.Except (MonadError(..)) +import Control.Monad.State.Class (MonadState(..)) +import qualified Control.Monad.State.Class as State +import qualified Servant.Types.SourceT as S +import Servant.API (SourceIO) + +import Utils (throwExceptT) + +import Yesod.Servant (ServantApi, servantApiBaseUrl) + +import qualified Data.ByteString as BS +import qualified Data.ByteString.Lazy as Lazy (ByteString) +import qualified Data.ByteString.Lazy as LBS hiding (ByteString) +import qualified Data.Binary.Builder as B +import Network.HTTP.Media (renderHeader) +import Control.Monad.Fail + import Control.Lens as X hiding ((<.), elements) import Network.IP.Addr as X (IP) @@ -133,3 +163,105 @@ lawsCheckHspec p = parallel . describe (show $ typeRep p) . mapM_ (checkHspec . where checkHspec (Laws className properties) = describe className $ forM_ properties $ \(name, prop) -> it name $ property prop + + +newtype ServantExample a = ServantExample + { unServantExample :: ReaderT ServantExampleEnv (ExceptT ClientError Wai.Session) a + } deriving stock (Generic, Typeable) + deriving newtype (Functor, Applicative, Monad, MonadIO, MonadReader ServantExampleEnv, MonadError ClientError, MonadThrow, MonadCatch, MonadState Wai.ClientState) + +data ServantExampleEnv = ServantExampleEnv + { yseBaseUrl :: BaseUrl + , yseMakeClientRequest :: BaseUrl -> Servant.Request -> IO Wai.Request + } deriving (Generic, Typeable) + +runServantExample :: (Route (ServantApi proxy) -> Route UniWorX) -> ServantExample a -> YesodExample UniWorX a +runServantExample apiR (ServantExample act) = do + yseBaseUrl <- runHandler $ servantApiBaseUrl apiR + let yseMakeClientRequest burl Servant.Request{..} = do + ((body, bodyLength), contentTypeHdr) <- case requestBody of + Nothing -> return ((return BS.empty, Wai.KnownLength 0), Nothing) + Just (body', typ) -> let (mkBody, bLength) = convertBody body' + in (, Just (hContentType, renderHeader typ)) . (, bLength) <$> mkBody + + return $ Wai.defaultRequest + { Wai.requestMethod = requestMethod + , Wai.requestHeaders = maybeToList acceptHdr ++ maybeToList contentTypeHdr ++ headers + , Wai.requestHeaderHost = + let BaseUrl{..} = yseBaseUrl + in Just . encodeUtf8 . pack $ baseUrlHost <> bool (":" <> show baseUrlPort) mempty (baseUrlPort == 80) + , Wai.requestBody = body, Wai.requestBodyLength = bodyLength + , Wai.isSecure = isSecure + } + & flip Wai.setPath (encodeUtf8 (pack $ baseUrlPath burl) <> toStrict (B.toLazyByteString requestPath) <> Wai.renderQuery True (toList requestQueryString)) + where + headers = filter (\(h, _) -> h `notElem` [hAccept, hContentType, hHost]) $ toList requestHeaders + + acceptHdr + | null hs = Nothing + | otherwise = Just (hAccept, renderHeader hs) + where + hs = toList requestAccept + + convertBody :: Servant.RequestBody -> (IO (IO ByteString), Wai.RequestBodyLength) + convertBody bd = case bd of + Servant.RequestBodyLBS body' -> ( givesPopper . S.source . map fromStrict $ LBS.toChunks body' + , Wai.KnownLength . fromIntegral $ LBS.length body' + ) + Servant.RequestBodyBS body' -> ( return $ return body' + , Wai.KnownLength . fromIntegral $ BS.length body' + ) + Servant.RequestBodySource sourceIO -> ( givesPopper sourceIO + , Wai.ChunkedBody + ) + where + givesPopper :: SourceIO Lazy.ByteString -> IO (IO ByteString) + givesPopper sourceIO = S.unSourceT sourceIO $ \step0 -> do + ref <- newMVar step0 + return $ modifyMVar ref nextBs + + nextBs S.Stop = return (S.Stop, BS.empty) + nextBs (S.Error err) = fail err + nextBs (S.Skip s) = nextBs s + nextBs (S.Effect ms) = ms >>= nextBs + nextBs (S.Yield lbs s) = case LBS.toChunks lbs of + [] -> nextBs s + (x:xs) | BS.null x -> nextBs step' + | otherwise -> return (step', x) + where + step' = S.Yield (LBS.fromChunks xs) s + + isSecure = case baseUrlScheme burl of + Servant.Http -> False + Servant.Https -> True + YesodExampleData waiApp _ _ _ <- State.get + liftIO . flip Wai.runSession waiApp . throwExceptT $ runReaderT act ServantExampleEnv{..} + +instance RunClient ServantExample where + runRequestAcceptStatus acceptStatus req = do + ServantExampleEnv{..} <- ask + waiRequest <- liftIO $ yseMakeClientRequest yseBaseUrl req + waiResponse@Wai.SResponse{..} <- ServantExample . lift . lift $ Wai.request waiRequest + let Status{..} = simpleStatus + statusOk = case acceptStatus of + Nothing -> 200 <= statusCode && statusCode < 300 + Just good -> simpleStatus `elem` good + response = (waiResponseToResponse waiResponse) { Servant.responseHttpVersion = Wai.httpVersion waiRequest } + unless statusOk $ + throwError $ mkFailureResponse yseBaseUrl req response + return response + where + mkFailureResponse :: BaseUrl -> Servant.Request -> Servant.ResponseF Lazy.ByteString -> ClientError + mkFailureResponse burl request' = + FailureResponse (bimap (const ()) f request') + where + f b = (burl, LBS.toStrict $ B.toLazyByteString b) + + waiResponseToResponse :: Wai.SResponse -> Servant.Response + waiResponseToResponse Wai.SResponse{..} = Servant.Response + { responseStatusCode = simpleStatus + , responseBody = simpleBody + , responseHeaders = fromList simpleHeaders + , responseHttpVersion = error "WAI Response does not carry http version information" + } + throwClientError = throwError