Compare commits

...

3 Commits

3 changed files with 58 additions and 36 deletions

View File

@ -12,7 +12,7 @@ import Prelude.Compat
import Text.PrettyPrint import Text.PrettyPrint
data PredicateFailure data PredicateFailure
= PredicateFailure T.Text (Maybe C.Request) (C.Response LBS.ByteString) = PredicateFailure T.Text (C.Request) (C.Response LBS.ByteString)
deriving (Typeable, Generic) deriving (Typeable, Generic)
instance Exception ServerEqualityFailure where instance Exception ServerEqualityFailure where
@ -71,10 +71,5 @@ prettyPredicateFailure :: PredicateFailure -> Doc
prettyPredicateFailure (PredicateFailure predicate req resp) = prettyPredicateFailure (PredicateFailure predicate req resp) =
text "Predicate failed" $$ (nest 5 $ text "Predicate failed" $$ (nest 5 $
text "Predicate:" <+> (text $ T.unpack predicate) text "Predicate:" <+> (text $ T.unpack predicate)
$$ r $$ prettyReq req
$$ prettyResp resp) $$ prettyResp resp)
where
r = case req of
Nothing -> text ""
Just v -> prettyReq v

View File

@ -42,8 +42,9 @@ import Servant.QuickCheck.Internal.ErrorTypes
-- --
-- /Since 0.0.0.0/ -- /Since 0.0.0.0/
not500 :: ResponsePredicate not500 :: ResponsePredicate
not500 = ResponsePredicate $ \resp -> not500 = ResponsePredicate $ \req resp ->
when (responseStatus resp == status500) $ throw $ PredicateFailure "not500" Nothing resp when (responseStatus resp == status500) $
throw $ PredicateFailure "not500" req resp
-- | [__Optional__] -- | [__Optional__]
-- --
@ -58,7 +59,7 @@ notLongerThan maxAllowed
resp <- httpLbs req mgr resp <- httpLbs req mgr
end <- getTime Monotonic end <- getTime Monotonic
when (toNanoSecs (end `diffTimeSpec` start) > maxAllowed) $ when (toNanoSecs (end `diffTimeSpec` start) > maxAllowed) $
throw $ PredicateFailure "notLongerThan" (Just req) resp throw $ PredicateFailure "notLongerThan" req resp
return [] return []
-- | [__Best Practice__] -- | [__Best Practice__]
@ -84,8 +85,8 @@ notLongerThan maxAllowed
-- /Since 0.0.0.0/ -- /Since 0.0.0.0/
onlyJsonObjects :: ResponsePredicate onlyJsonObjects :: ResponsePredicate
onlyJsonObjects onlyJsonObjects
= ResponsePredicate (\resp -> case go resp of = ResponsePredicate (\req resp -> case go resp of
Nothing -> throw $ PredicateFailure "onlyJsonObjects" Nothing resp Nothing -> throw $ PredicateFailure "onlyJsonObjects" req resp
Just () -> return ()) Just () -> return ())
where where
go r = do go r = do
@ -120,12 +121,12 @@ createContainsValidLocation
resp <- httpLbs req mgr resp <- httpLbs req mgr
if responseStatus resp == status201 if responseStatus resp == status201
then case lookup "Location" $ responseHeaders resp of then case lookup "Location" $ responseHeaders resp of
Nothing -> throw $ PredicateFailure n (Just req) resp Nothing -> throw $ PredicateFailure n req resp
Just l -> case parseRequest $ SBSC.unpack l of Just l -> case parseRequest $ SBSC.unpack l of
Nothing -> throw $ PredicateFailure n (Just req) resp Nothing -> throw $ PredicateFailure n req resp
Just x -> do Just x -> do
resp2 <- httpLbs x mgr resp2 <- httpLbs x mgr
status2XX (Just req) resp2 n status2XX req resp2 n
return [resp, resp2] return [resp, resp2]
else return [resp] else return [resp]
@ -160,8 +161,8 @@ getsHaveLastModifiedHeader
if (method req == methodGet) if (method req == methodGet)
then do then do
resp <- httpLbs req mgr resp <- httpLbs req mgr
unless (hasValidHeader "Last-Modified" isRFC822Date resp) $ do unless (hasValidHeader "Last-Modified" isRFC822Date resp) $
throw $ PredicateFailure "getsHaveLastModifiedHeader" (Just req) resp throw $ PredicateFailure "getsHaveLastModifiedHeader" req resp
return [resp] return [resp]
else return [] else return []
@ -193,7 +194,7 @@ notAllowedContainsAllowHeader
| m <- [minBound .. maxBound ] | m <- [minBound .. maxBound ]
, renderStdMethod m /= method req ] , renderStdMethod m /= method req ]
case filter pred' resp of case filter pred' resp of
(x:_) -> throw $ PredicateFailure "notAllowedContainsAllowHeader" (Just req) x (x:_) -> throw $ PredicateFailure "notAllowedContainsAllowHeader" req x
[] -> return resp [] -> return resp
where where
pred' resp = responseStatus resp == status405 && not (hasValidHeader "Allow" go resp) pred' resp = responseStatus resp == status405 && not (hasValidHeader "Allow" go resp)
@ -226,7 +227,7 @@ honoursAcceptHeader
sacc = fromMaybe "*/*" $ lookup "Accept" (requestHeaders req) sacc = fromMaybe "*/*" $ lookup "Accept" (requestHeaders req)
if status100 < scode && scode < status300 if status100 < scode && scode < status300
then if isJust $ sctype >>= \x -> matchAccept [x] sacc then if isJust $ sctype >>= \x -> matchAccept [x] sacc
then throw $ PredicateFailure "honoursAcceptHeader" (Just req) resp then throw $ PredicateFailure "honoursAcceptHeader" req resp
else return [resp] else return [resp]
else return [resp] else return [resp]
@ -251,8 +252,8 @@ getsHaveCacheControlHeader
if (method req == methodGet) if (method req == methodGet)
then do then do
resp <- httpLbs req mgr resp <- httpLbs req mgr
unless (hasValidHeader "Cache-Control" (const True) resp) $ do unless (hasValidHeader "Cache-Control" (const True) resp) $
throw $ PredicateFailure "getsHaveCacheControlHeader" (Just req) resp throw $ PredicateFailure "getsHaveCacheControlHeader" req resp
return [resp] return [resp]
else return [] else return []
@ -268,7 +269,7 @@ headsHaveCacheControlHeader
then do then do
resp <- httpLbs req mgr resp <- httpLbs req mgr
unless (hasValidHeader "Cache-Control" (const True) resp) $ unless (hasValidHeader "Cache-Control" (const True) resp) $
throw $ PredicateFailure "headsHaveCacheControlHeader" (Just req) resp throw $ PredicateFailure "headsHaveCacheControlHeader" req resp
return [resp] return [resp]
else return [] else return []
{- {-
@ -334,10 +335,10 @@ linkHeadersAreValid
-- /Since 0.0.0.0/ -- /Since 0.0.0.0/
unauthorizedContainsWWWAuthenticate :: ResponsePredicate unauthorizedContainsWWWAuthenticate :: ResponsePredicate
unauthorizedContainsWWWAuthenticate unauthorizedContainsWWWAuthenticate
= ResponsePredicate $ \resp -> = ResponsePredicate $ \req resp ->
if responseStatus resp == status401 if responseStatus resp == status401
then unless (hasValidHeader "WWW-Authenticate" (const True) resp) $ then unless (hasValidHeader "WWW-Authenticate" (const True) resp) $
throw $ PredicateFailure "unauthorizedContainsWWWAuthenticate" Nothing resp throw $ PredicateFailure "unauthorizedContainsWWWAuthenticate" req resp
else return () else return ()
@ -354,12 +355,12 @@ unauthorizedContainsWWWAuthenticate
-- /Since 0.3.0.0/ -- /Since 0.3.0.0/
htmlIncludesDoctype :: ResponsePredicate htmlIncludesDoctype :: ResponsePredicate
htmlIncludesDoctype htmlIncludesDoctype
= ResponsePredicate $ \resp -> = ResponsePredicate $ \req resp ->
if hasValidHeader "Content-Type" (SBS.isPrefixOf . foldCase $ "text/html") resp if hasValidHeader "Content-Type" (SBS.isPrefixOf . foldCase $ "text/html") resp
then do then do
let htmlContent = foldCase . LBS.take 20 $ responseBody resp let htmlContent = foldCase . LBS.take 20 $ responseBody resp
unless (LBS.isPrefixOf (foldCase "<!doctype html>") htmlContent) $ unless (LBS.isPrefixOf (foldCase "<!doctype html>") htmlContent) $
throw $ PredicateFailure "htmlIncludesDoctype" Nothing resp throw $ PredicateFailure "htmlIncludesDoctype" req resp
else return () else return ()
-- * Predicate logic -- * Predicate logic
@ -374,12 +375,12 @@ htmlIncludesDoctype
-- --
-- /Since 0.0.0.0/ -- /Since 0.0.0.0/
newtype ResponsePredicate = ResponsePredicate newtype ResponsePredicate = ResponsePredicate
{ getResponsePredicate :: Response LBS.ByteString -> IO () { getResponsePredicate :: Request -> Response LBS.ByteString -> IO ()
} deriving (Generic) } deriving (Generic)
instance Monoid ResponsePredicate where instance Monoid ResponsePredicate where
mempty = ResponsePredicate $ const $ return () mempty = ResponsePredicate (\req resp -> return ())
ResponsePredicate a `mappend` ResponsePredicate b = ResponsePredicate $ \x -> a x >> b x ResponsePredicate a `mappend` ResponsePredicate b = ResponsePredicate $ \x y -> a x y >> b x y
-- | A predicate that depends on both the request and the response. -- | A predicate that depends on both the request and the response.
-- --
@ -429,7 +430,8 @@ finishPredicates p req mgr = go `catch` \(e :: PredicateFailure) -> return $ Jus
where where
go = do go = do
resps <- getRequestPredicate (requestPredicates p) req mgr resps <- getRequestPredicate (requestPredicates p) req mgr
mapM_ (getResponsePredicate $ responsePredicates p) resps let responder = getResponsePredicate (responsePredicates p) req
mapM_ responder resps
return Nothing return Nothing
-- * helpers -- * helpers
@ -445,8 +447,8 @@ isRFC822Date s
Nothing -> False Nothing -> False
Just (_ :: UTCTime) -> True Just (_ :: UTCTime) -> True
status2XX :: Monad m => Maybe Request -> Response LBS.ByteString -> T.Text -> m () status2XX :: Monad m => Request -> Response LBS.ByteString -> T.Text -> m ()
status2XX mreq resp t status2XX req resp t
| status200 <= responseStatus resp && responseStatus resp < status300 | status200 <= responseStatus resp && responseStatus resp < status300
= return () = return ()
| otherwise = throw $ PredicateFailure t mreq resp | otherwise = throw $ PredicateFailure t req resp

View File

@ -46,11 +46,13 @@ spec = do
serversEqualSpec serversEqualSpec
serverSatisfiesSpec serverSatisfiesSpec
isComprehensiveSpec isComprehensiveSpec
no500s
onlyJsonObjectSpec onlyJsonObjectSpec
notLongerThanSpec notLongerThanSpec
queryParamsSpec queryParamsSpec
queryFlagsSpec queryFlagsSpec
deepPathSpec deepPathSpec
authServerCheck
htmlDocTypesSpec htmlDocTypesSpec
unbiasedGenerationSpec unbiasedGenerationSpec
@ -127,6 +129,15 @@ serverSatisfiesSpec = describe "serverSatisfies" $ do
show err `shouldContain` "Body" show err `shouldContain` "Body"
no500s :: Spec
no500s = describe "no500s" $ do
it "fails correctly" $ do
FailedWith err <- withServantServerAndContext api2 ctx server500fail $ \burl -> do
evalExample $ serverSatisfies api2 burl args
(not500 <%> mempty)
show err `shouldContain` "not500"
onlyJsonObjectSpec :: Spec onlyJsonObjectSpec :: Spec
onlyJsonObjectSpec = describe "onlyJsonObjects" $ do onlyJsonObjectSpec = describe "onlyJsonObjects" $ do
@ -193,6 +204,17 @@ queryFlagsSpec = describe "QueryFlags" $ do
qs = C.unpack $ queryString req qs = C.unpack $ queryString req
qs `shouldBe` "one&two" qs `shouldBe` "one&two"
authServerCheck :: Spec
authServerCheck = describe "authenticate endpoints" $ do
it "authorization failure without WWWAuthenticate header fails correctly" $ do
FailedWith err <- withServantServerAndContext api2 ctx authFailServer $ \burl -> do
evalExample $ serverSatisfies api2 burl args
(unauthorizedContainsWWWAuthenticate <%> mempty)
show err `shouldContain` "unauthorizedContainsWWWAuthenticate"
-- Large API Randomness Testing Helper
htmlDocTypesSpec :: Spec htmlDocTypesSpec :: Spec
htmlDocTypesSpec = describe "HtmlDocTypes" $ do htmlDocTypesSpec = describe "HtmlDocTypes" $ do
@ -217,7 +239,6 @@ makeRandomRequest large burl = do
req <- generate $ runGenRequest large req <- generate $ runGenRequest large
pure $ fst . fromJust . C.readInteger . C.drop 1 . path $ req burl pure $ fst . fromJust . C.readInteger . C.drop 1 . path $ req burl
unbiasedGenerationSpec :: Spec unbiasedGenerationSpec :: Spec
unbiasedGenerationSpec = describe "Unbiased Generation of requests" $ unbiasedGenerationSpec = describe "Unbiased Generation of requests" $
@ -274,13 +295,18 @@ type DeepAPI = "one" :> "two" :> "three":> Get '[JSON] ()
deepAPI :: Proxy DeepAPI deepAPI :: Proxy DeepAPI
deepAPI = Proxy deepAPI = Proxy
server2 :: IO (Server API2) server2 :: IO (Server API2)
server2 = return $ return 1 server2 = return $ return 1
server3 :: IO (Server API2) server3 :: IO (Server API2)
server3 = return $ return 2 server3 = return $ return 2
server500fail :: IO (Server API2)
server500fail = return $ throwError $ err500 { errBody = "BOOM!" }
authFailServer :: IO (Server API2)
authFailServer = return $ throwError $ err401 { errBody = "Login failure but missing header"}
-- With Doctypes -- With Doctypes
type HtmlDoctype = Get '[HTML] Blaze.Html type HtmlDoctype = Get '[HTML] Blaze.Html
@ -293,7 +319,6 @@ docTypeServer = pure $ pure $ Blaze5.docTypeHtml $ Blaze5.span "Hello Test!"
noDocTypeServer :: IO (Server HtmlDoctype) noDocTypeServer :: IO (Server HtmlDoctype)
noDocTypeServer = pure $ pure $ Blaze.text "Hello Test!" noDocTypeServer = pure $ pure $ Blaze.text "Hello Test!"
-- Api for unbiased generation of requests tests -- Api for unbiased generation of requests tests
largeApi :: Proxy LargeAPI largeApi :: Proxy LargeAPI
largeApi = Proxy largeApi = Proxy