From 002775e19296bd75cf016146e594df2f6101948b Mon Sep 17 00:00:00 2001 From: Gregor Kleen Date: Fri, 22 May 2020 11:29:30 +0200 Subject: [PATCH] feat(dry-run): implement dry-run BREAKING CHANGE: runDBRead --- load.sh | 10 ++ package.yaml | 12 ++ src/Foundation.hs | 158 +++++++++++++++++++++++---- src/Handler/Profile.hs | 1 - src/Handler/Users.hs | 1 - src/Handler/Utils/Invitations.hs | 2 - src/Handler/Utils/Tokens.hs | 47 -------- src/Utils.hs | 5 +- src/Utils/Parameters.hs | 3 +- stack.yaml | 2 + stack.yaml.lock | 7 ++ test/Load.hs | 181 +++++++++++++++++++++++++++++++ 12 files changed, 357 insertions(+), 72 deletions(-) create mode 100755 load.sh delete mode 100644 src/Handler/Utils/Tokens.hs create mode 100644 test/Load.hs diff --git a/load.sh b/load.sh new file mode 100755 index 000000000..c66108b5e --- /dev/null +++ b/load.sh @@ -0,0 +1,10 @@ +#!/usr/bin/env bash +# Options: see /test/Load.hs (Main) + +set -e + +[ "${FLOCKER}" != "$0" ] && exec env FLOCKER="$0" flock -en .stack-work.lock "$0" "$@" || : + +stack build --fast --flag uniworx:-library-only --flag uniworx:dev + +stack exec uniworxload -- $@ diff --git a/package.yaml b/package.yaml index 3d35d0ade..458f1bd78 100644 --- a/package.yaml +++ b/package.yaml @@ -266,6 +266,18 @@ executables: when: - condition: flag(library-only) buildable: false + uniworxload: + main: Load.hs + ghc-options: + - -main-is Load + source-dirs: test + dependencies: + - uniworx + - normaldistribution + other-modules: [] + when: + - condition: flag(library-only) + buildable: false # Test suite tests: diff --git a/src/Foundation.hs b/src/Foundation.hs index 299d1d9aa..1b013618f 100644 --- a/src/Foundation.hs +++ b/src/Foundation.hs @@ -2,7 +2,7 @@ {-# LANGUAGE GeneralizedNewtypeDeriving #-} {-# LANGUAGE DuplicateRecordFields #-} {-# LANGUAGE OverloadedLabels #-} -{-# OPTIONS_GHC -fno-warn-orphans -fno-warn-incomplete-uni-patterns #-} -- MonadCrypto +{-# OPTIONS_GHC -fno-warn-orphans -fno-warn-incomplete-uni-patterns -fno-warn-redundant-constraints #-} -- MonadCrypto module Foundation ( module Foundation @@ -64,6 +64,7 @@ import Control.Monad.Except (MonadError(..)) import Control.Monad.Trans.State (execStateT) import Control.Monad.Writer.Class (MonadWriter(..)) import Control.Monad.Memo.Class (MonadMemo(..), for4) +import Control.Monad.Reader.Class (MonadReader(local)) import qualified Control.Monad.Catch as C import Handler.Utils.StudyFeatures @@ -103,6 +104,9 @@ import qualified Web.ServerSession.Frontend.Yesod.Jwt as JwtSession import Web.Cookie +import Yesod.Core.Types (GHState(..), HandlerData(handlerState, handlerEnv), RunHandlerEnv(rheSite, rheChild)) +import Database.Persist.Sql (transactionUndo, SqlReadBackend(..)) + -- | Convenient Type Synonyms: type DB = YesodDB UniWorX type Form x = Html -> MForm (HandlerFor UniWorX) (FormResult x, Widget) @@ -256,7 +260,7 @@ instance Exception InvalidAuthTag data AccessPredicate = APPure (Maybe (AuthId UniWorX) -> Route UniWorX -> Bool -> Reader MsgRenderer AuthResult) | APHandler (Maybe (AuthId UniWorX) -> Route UniWorX -> Bool -> Handler AuthResult) - | APDB (Maybe (AuthId UniWorX) -> Route UniWorX -> Bool -> DB AuthResult) + | APDB (Maybe (AuthId UniWorX) -> Route UniWorX -> Bool -> ReaderT SqlReadBackend Handler AuthResult) class (MonadHandler m, HandlerSite m ~ UniWorX) => MonadAP m where evalAccessPred :: AccessPredicate -> Maybe (AuthId UniWorX) -> Route UniWorX -> Bool -> m AuthResult @@ -265,10 +269,10 @@ instance {-# INCOHERENT #-} (MonadHandler m, HandlerSite m ~ UniWorX) => MonadAP evalAccessPred aPred aid r w = liftHandler $ case aPred of (APPure p) -> runReader (p aid r w) <$> getMsgRenderer (APHandler p) -> p aid r w - (APDB p) -> runDB $ p aid r w + (APDB p) -> runDBRead $ p aid r w -instance (MonadHandler m, HandlerSite m ~ UniWorX, backend ~ YesodPersistBackend UniWorX) => MonadAP (ReaderT backend m) where - evalAccessPred aPred aid r w = mapReaderT liftHandler $ case aPred of +instance (MonadHandler m, HandlerSite m ~ UniWorX, BackendCompatible SqlBackend backend) => MonadAP (ReaderT backend m) where + evalAccessPred aPred aid r w = mapReaderT liftHandler . withReaderT (SqlReadBackend . projectBackend) $ case aPred of (APPure p) -> lift $ runReader (p aid r w) <$> getMsgRenderer (APHandler p) -> lift $ p aid r w (APDB p) -> p aid r w @@ -333,7 +337,7 @@ askBearerUnsafe :: forall m. => ExceptT AuthResult m (BearerToken UniWorX) -- | This performs /no/ meaningful validation of the `BearerToken` -- --- Use `Handler.Utils.Tokens.requireBearerToken` or `Handler.Utils.Tokens.maybeBearerToken` instead +-- Use `requireBearerToken` or `maybeBearerToken` instead askBearerUnsafe = $cachedHere $ do bearer <- maybeMExceptT (unauthorizedI MsgUnauthorizedNoToken) askBearer catch (decodeBearer bearer) $ \case @@ -343,10 +347,10 @@ askBearerUnsafe = $cachedHere $ do $logWarnS "AuthToken" $ tshow other throwError =<< unauthorizedI MsgUnauthorizedTokenInvalid -validateBearer :: Maybe (AuthId UniWorX) -> Route UniWorX -> Bool -> BearerToken UniWorX -> DB AuthResult +validateBearer :: Maybe (AuthId UniWorX) -> Route UniWorX -> Bool -> BearerToken UniWorX -> ReaderT SqlReadBackend Handler AuthResult validateBearer mAuthId' route' isWrite' token' = $runCachedMemoT $ for4 memo validateBearer' mAuthId' route' isWrite' token' where - validateBearer' :: _ -> _ -> _ -> _ -> CachedMemoT (Maybe (AuthId UniWorX), Route UniWorX, Bool, BearerToken UniWorX) AuthResult DB AuthResult + validateBearer' :: _ -> _ -> _ -> _ -> CachedMemoT (Maybe (AuthId UniWorX), Route UniWorX, Bool, BearerToken UniWorX) AuthResult (ReaderT SqlReadBackend Handler) AuthResult validateBearer' mAuthId route isWrite BearerToken{..} = lift . exceptT return return $ do guardMExceptT (maybe True (HashSet.member route) bearerRoutes) (unauthorizedI MsgUnauthorizedTokenInvalidRoute) @@ -381,6 +385,79 @@ validateBearer mAuthId' route' isWrite' token' = $runCachedMemoT $ for4 memo val return Authorized +maybeBearerToken :: (MonadHandler m, HandlerSite m ~ UniWorX, MonadCatch m) => m (Maybe (BearerToken UniWorX)) +maybeBearerToken = runMaybeT $ catchIfMaybeT cPred requireBearerToken + where + cPred err = any ($ err) + [ is $ _HCError . _PermissionDenied + , is $ _HCError . _NotAuthenticated + ] + +requireBearerToken :: (MonadHandler m, HandlerSite m ~ UniWorX) => m (BearerToken UniWorX) +requireBearerToken = liftHandler $ do + bearer <- exceptT (guardAuthResult >=> error "askToken should not throw `Authorized`") return askBearerUnsafe + mAuthId <- maybeAuthId + currentRoute <- maybe (permissionDeniedI MsgUnauthorizedToken404) return =<< getCurrentRoute + isWrite <- isWriteRequest currentRoute + guardAuthResult <=< runDBRead $ validateBearer mAuthId currentRoute isWrite bearer + return bearer + +requireCurrentBearerRestrictions :: ( MonadHandler m + , HandlerSite m ~ UniWorX + , FromJSON a + , ToJSON a + ) + => m (Maybe a) +requireCurrentBearerRestrictions = runMaybeT $ do + bearer <- requireBearerToken + route <- MaybeT getCurrentRoute + hoistMaybe $ bearer ^? _bearerRestrictionIx route + +maybeCurrentBearerRestrictions :: ( MonadHandler m + , HandlerSite m ~ UniWorX + , MonadCatch m + , FromJSON a + , ToJSON a + ) + => m (Maybe a) +maybeCurrentBearerRestrictions = runMaybeT $ do + bearer <- MaybeT maybeBearerToken + route <- MaybeT getCurrentRoute + hoistMaybe $ bearer ^? _bearerRestrictionIx route + +isDryRun :: forall m. + ( MonadHandler m + , HandlerSite m ~ UniWorX + , MonadCatch m + ) + => m Bool +isDryRun = $cachedHere $ orM + [ hasGlobalPostParam PostDryRun + , hasGlobalGetParam GetDryRun + , and2M bearerDryRun bearerRequired + ] + where + bearerDryRun = maybeT (return False) $ MaybeT maybeCurrentBearerRestrictions >>= hoistMaybe . \case + JSON.Object hm -> Just $ HashMap.member "dry-run" hm + _other -> Nothing + bearerRequired = maybeT (return True) . catchIfMaybeT cPred . liftHandler $ do + mAuthId <- maybeAuthId + currentRoute <- maybe (permissionDeniedI MsgUnauthorizedToken404) return =<< getCurrentRoute + isWrite <- isWriteRequest currentRoute + + let noTokenAuth :: AuthDNF -> AuthDNF + noTokenAuth = over _dnfTerms . Set.filter . noneOf (re _nullable . folded) $ (== AuthToken) . plVar + + dnf <- either throwM return $ routeAuthTags currentRoute + guardAuthResult <=< fmap fst . runWriterT $ evalAuthTags (AuthTagActive $ const True) (noTokenAuth dnf) mAuthId currentRoute isWrite + + return False + + cPred err = any ($ err) + [ is $ _HCError . _PermissionDenied + , is $ _HCError . _NotAuthenticated + ] + tagAccessPredicate :: AuthTag -> AccessPredicate tagAccessPredicate AuthFree = trueAP @@ -1096,7 +1173,7 @@ tagAccessPredicate AuthParticipant = APDB $ \mAuthId route _ -> case route of where isCourseParticipant tid ssh csh participant onlyActive = do let - authorizedIfExists :: E.From a => (a -> E.SqlQuery b) -> ExceptT AuthResult DB () + authorizedIfExists :: E.From a => (a -> E.SqlQuery b) -> ExceptT AuthResult (ReaderT SqlReadBackend Handler) () authorizedIfExists = flip whenExceptT Authorized <=< lift . E.selectExists . E.from -- participant is currently registered mapExceptT ($cachedHereBinary (participant, tid, ssh, csh)) . authorizedIfExists $ \(course `E.InnerJoin` courseParticipant) -> do @@ -1357,8 +1434,12 @@ tagAccessPredicate AuthAuthentication = APDB $ \mAuthId route _ -> case route of guard $ not systemMessageAuthenticatedOnly || isAuthenticated return Authorized r -> $unsupportedAuthPredicate AuthAuthentication r -tagAccessPredicate AuthRead = APHandler . const . const $ bool (return Authorized) (unauthorizedI MsgUnauthorizedWrite) -tagAccessPredicate AuthWrite = APHandler . const . const $ bool (unauthorizedI MsgUnauthorized) (return Authorized) +tagAccessPredicate AuthRead = APPure $ \_ _ isWrite -> do + MsgRenderer mr <- ask + return $ bool Authorized (Unauthorized $ mr MsgUnauthorizedWrite) isWrite +tagAccessPredicate AuthWrite = APPure $ \_ _ isWrite -> do + MsgRenderer mr <- ask + return $ bool (Unauthorized $ mr MsgUnauthorized) Authorized isWrite authTagSpecificity :: AuthTag -> AuthTag -> Ordering @@ -1430,7 +1511,7 @@ evalAuthTags AuthTagActive{..} (map (Set.toList . toNullable) . Set.toList . dnf evalDNF :: [[AuthLiteral]] -> WriterT (Set AuthTag) m AuthResult evalDNF = foldr (\ats ar -> ar `orAR'` foldr (\aTag ar' -> ar' `andAR'` evalAuthLiteral aTag) (return $ trueAR mr) ats) (return $ falseAR mr) - $logDebugS "evalAuthTags" . tshow . (route, isWrite, )$ map (map $ id &&& authTagIsActive . plVar) authDNF + $logDebugS "evalAuthTags" . tshow . (route, isWrite, ) $ map (map $ id &&& authTagIsActive . plVar) authDNF result <- evalDNF $ filter (all $ authTagIsActive . plVar) authDNF @@ -1449,7 +1530,7 @@ evalAccessFor mAuthId route isWrite = do dnf <- either throwM return $ routeAuthTags route fmap fst . runWriterT $ evalAuthTags (AuthTagActive $ const True) dnf mAuthId route isWrite -evalAccessForDB :: (MonadThrow m, MonadHandler m, HandlerSite m ~ UniWorX) => Maybe (AuthId UniWorX) -> Route UniWorX -> Bool -> ReaderT (YesodPersistBackend UniWorX) m AuthResult +evalAccessForDB :: (MonadThrow m, MonadHandler m, HandlerSite m ~ UniWorX, BackendCompatible SqlBackend backend) => Maybe (AuthId UniWorX) -> Route UniWorX -> Bool -> ReaderT backend m AuthResult evalAccessForDB = evalAccessFor evalAccess :: (MonadThrow m, MonadHandler m, HandlerSite m ~ UniWorX) => Route UniWorX -> Bool -> m AuthResult @@ -1460,7 +1541,7 @@ evalAccess route isWrite = do (result, deactivated) <- runWriterT $ evalAuthTags tagActive dnf mAuthId route isWrite result <$ tellSessionJson SessionInactiveAuthTags deactivated -evalAccessDB :: (MonadThrow m, MonadHandler m, HandlerSite m ~ UniWorX) => Route UniWorX -> Bool -> ReaderT (YesodPersistBackend UniWorX) m AuthResult +evalAccessDB :: (MonadThrow m, MonadHandler m, HandlerSite m ~ UniWorX, BackendCompatible SqlBackend backend) => Route UniWorX -> Bool -> ReaderT backend m AuthResult evalAccessDB = evalAccess -- | Check whether the current user is authorized by `evalAccess` for the given route @@ -1588,8 +1669,28 @@ instance Yesod UniWorX where -- b) Validates that incoming write requests include that token in either a header or POST parameter. -- To add it, chain it together with the defaultMiddleware: yesodMiddleware = defaultYesodMiddleware . defaultCsrfMiddleware -- For details, see the CSRF documentation in the Yesod.Core.Handler module of the yesod-core package. - yesodMiddleware = observeYesodCacheSizeMiddleware . languagesMiddleware appLanguages . headerMessagesMiddleware . defaultYesodMiddleware . normalizeRouteMiddleware . csrfMiddleware . updateFavouritesMiddleware . storeBearerMiddleware + yesodMiddleware = storeBearerMiddleware . csrfMiddleware . dryRunMiddleware . observeYesodCacheSizeMiddleware . languagesMiddleware appLanguages . headerMessagesMiddleware . defaultYesodMiddleware . normalizeRouteMiddleware . updateFavouritesMiddleware where + dryRunMiddleware :: Handler a -> Handler a + dryRunMiddleware handler = do + dryRun <- isDryRun + if | dryRun -> do + hData <- ask + prevState <- readIORef (handlerState hData) + let + restoreSession = + modifyIORef (handlerState hData) $ + \hst -> hst { ghsSession = ghsSession prevState + , ghsCache = ghsCache prevState + , ghsCacheBy = ghsCacheBy prevState + } + site' = (rheSite $ handlerEnv hData) { appMemcached = Nothing } + handler' = local (\hd -> hd { handlerEnv = (handlerEnv hd) { rheSite = site', rheChild = site' } }) handler + + addCustomHeader HeaderDryRun ("1" :: Text) + + handler' `finally` restoreSession + | otherwise -> handler updateFavouritesMiddleware :: Handler a -> Handler a updateFavouritesMiddleware handler = (*> handler) . runMaybeT $ do route <- MaybeT getCurrentRoute @@ -1637,7 +1738,7 @@ instance Yesod UniWorX where storeBearerMiddleware handler = do askBearer >>= \case Just (Jwt bs) -> setSessionBS (toPathPiece SessionBearer) bs - Nothing -> return () + Nothing -> return () handler @@ -4458,16 +4559,34 @@ routeNormalizers = return newRoute +runDBRead :: ReaderT SqlReadBackend Handler a -> Handler a +runDBRead action = do + $logDebugS "YesodPersist" "runDBRead" + runSqlPool (withReaderT SqlReadBackend action) =<< appConnPool <$> getYesod + -- How to run database actions. instance YesodPersist UniWorX where type YesodPersistBackend UniWorX = SqlBackend runDB action = do + -- stack <- liftIO currentCallStack + -- $logDebugS "YesodPersist" . unlines $ "runDB" : map pack stack $logDebugS "YesodPersist" "runDB" - runSqlPool action =<< appConnPool <$> getYesod + dryRun <- isDryRun + let action' + | dryRun = action <* transactionUndo + | otherwise = action + runSqlPool action' =<< appConnPool <$> getYesod instance YesodPersistRunner UniWorX where getDBRunner = do (DBRunner{..}, cleanup) <- defaultGetDBRunner appConnPool - return . (, cleanup) $ DBRunner (\act -> $logDebugS "YesodPersist" "runDBRunner" >> runDBRunner act) + return . (, cleanup) $ DBRunner (\action -> do + dryRun <- isDryRun + let action' + | dryRun = action <* transactionUndo + | otherwise = action + $logDebugS "YesodPersist" "runDBRunner" + runDBRunner action' + ) data CampusUserConversionException = CampusUserInvalidIdent @@ -4952,7 +5071,8 @@ instance YesodAuth UniWorX where campusUserFailoverMode :: FailoverMode campusUserFailoverMode = FailoverUnlimited -instance YesodAuthPersist UniWorX +instance YesodAuthPersist UniWorX where + getAuthEntity = liftHandler . runDBRead . get unsafeHandler :: UniWorX -> Handler a -> IO a diff --git a/src/Handler/Profile.hs b/src/Handler/Profile.hs index 21a31b27d..7efe43acd 100644 --- a/src/Handler/Profile.hs +++ b/src/Handler/Profile.hs @@ -12,7 +12,6 @@ import Import import Handler.Utils import Handler.Utils.Profile -import Handler.Utils.Tokens -- import Colonnade hiding (fromMaybe, singleton) -- import Yesod.Colonnade diff --git a/src/Handler/Users.hs b/src/Handler/Users.hs index bccb6d646..6340ce78a 100644 --- a/src/Handler/Users.hs +++ b/src/Handler/Users.hs @@ -9,7 +9,6 @@ import Import import Jobs -- import Data.Text import Handler.Utils -import Handler.Utils.Tokens import Handler.Utils.Users import Handler.Utils.Invitations diff --git a/src/Handler/Utils/Invitations.hs b/src/Handler/Utils/Invitations.hs index 47873d036..6697f94cb 100644 --- a/src/Handler/Utils/Invitations.hs +++ b/src/Handler/Utils/Invitations.hs @@ -21,8 +21,6 @@ import Import import Utils.Form import Jobs.Queue -import Handler.Utils.Tokens - import Text.Hamlet import qualified Data.Conduit.List as C diff --git a/src/Handler/Utils/Tokens.hs b/src/Handler/Utils/Tokens.hs deleted file mode 100644 index 83266119f..000000000 --- a/src/Handler/Utils/Tokens.hs +++ /dev/null @@ -1,47 +0,0 @@ -module Handler.Utils.Tokens - ( maybeBearerToken, requireBearerToken - , maybeCurrentBearerRestrictions, requireCurrentBearerRestrictions - ) where - -import Import - - -maybeBearerToken :: (MonadHandler m, HandlerSite m ~ UniWorX, MonadCatch m) => m (Maybe (BearerToken UniWorX)) -maybeBearerToken = runMaybeT $ catchIfMaybeT cPred requireBearerToken - where - cPred err = any ($ err) - [ is $ _HCError . _PermissionDenied - , is $ _HCError . _NotAuthenticated - ] - -requireBearerToken :: (MonadHandler m, HandlerSite m ~ UniWorX) => m (BearerToken UniWorX) -requireBearerToken = liftHandler $ do - bearer <- exceptT (guardAuthResult >=> error "askToken should not throw `Authorized`") return askBearerUnsafe - mAuthId <- maybeAuthId - currentRoute <- maybe (permissionDeniedI MsgUnauthorizedToken404) return =<< getCurrentRoute - isWrite <- isWriteRequest currentRoute - guardAuthResult <=< runDB $ validateBearer mAuthId currentRoute isWrite bearer - return bearer - -requireCurrentBearerRestrictions :: ( MonadHandler m - , HandlerSite m ~ UniWorX - , FromJSON a - , ToJSON a - ) - => m (Maybe a) -requireCurrentBearerRestrictions = runMaybeT $ do - bearer <- requireBearerToken - route <- MaybeT getCurrentRoute - hoistMaybe $ bearer ^? _bearerRestrictionIx route - -maybeCurrentBearerRestrictions :: ( MonadHandler m - , HandlerSite m ~ UniWorX - , MonadCatch m - , FromJSON a - , ToJSON a - ) - => m (Maybe a) -maybeCurrentBearerRestrictions = runMaybeT $ do - bearer <- MaybeT maybeBearerToken - route <- MaybeT getCurrentRoute - hoistMaybe $ bearer ^? _bearerRestrictionIx route diff --git a/src/Utils.hs b/src/Utils.hs index 513f38e68..da20821a1 100644 --- a/src/Utils.hs +++ b/src/Utils.hs @@ -564,6 +564,9 @@ catchMaybeT _ act = catch (lift act) (const mzero :: e -> MaybeT m a) catchMPlus :: forall p m e a. (MonadPlus m, MonadCatch m, Exception e) => p e -> m a -> m a catchMPlus _ = handle (const mzero :: e -> m a) + +catchIfMPlus :: forall m e a. (MonadPlus m, MonadCatch m, Exception e) => (e -> Bool) -> m a -> m a +catchIfMPlus p act = catchIf p act (const mzero) mcons :: Maybe a -> [a] -> [a] mcons Nothing xs = xs @@ -841,7 +844,7 @@ choice = foldr (<|>) empty -- Custom HTTP Headers -- --------------------------------- -data CustomHeader = HeaderIsModal | HeaderDBTableShortcircuit | HeaderMassInputShortcircuit | HeaderAlerts | HeaderDBTableCanonicalURL +data CustomHeader = HeaderIsModal | HeaderDBTableShortcircuit | HeaderMassInputShortcircuit | HeaderAlerts | HeaderDBTableCanonicalURL | HeaderDryRun deriving (Eq, Ord, Enum, Bounded, Read, Show, Generic) instance Universe CustomHeader diff --git a/src/Utils/Parameters.hs b/src/Utils/Parameters.hs index 5d2018391..2ca4e2573 100644 --- a/src/Utils/Parameters.hs +++ b/src/Utils/Parameters.hs @@ -21,7 +21,7 @@ import Data.Universe import Control.Monad.Trans.Maybe (MaybeT(..)) -data GlobalGetParam = GetLang | GetReferer | GetBearer | GetRecipient | GetCsvExampleData +data GlobalGetParam = GetLang | GetReferer | GetBearer | GetRecipient | GetCsvExampleData | GetDryRun deriving (Eq, Ord, Enum, Bounded, Read, Show, Generic) deriving anyclass (Universe, Finite) @@ -62,6 +62,7 @@ data GlobalPostParam = PostFormIdentifier | PostLoginDummy | PostExamAutoOccurrencePrevious | PostLanguage + | PostDryRun deriving (Eq, Ord, Enum, Bounded, Read, Show, Generic) deriving anyclass (Universe, Finite) diff --git a/stack.yaml b/stack.yaml index 7c65d9ff1..d48ec8d06 100644 --- a/stack.yaml +++ b/stack.yaml @@ -106,5 +106,7 @@ extra-deps: - token-bucket-0.1.0.1 + - normaldistribution-1.1.0.3 + resolver: lts-15.12 allow-newer: true diff --git a/stack.yaml.lock b/stack.yaml.lock index e9fad1f50..bda38de91 100644 --- a/stack.yaml.lock +++ b/stack.yaml.lock @@ -288,6 +288,13 @@ packages: sha256: b0b4a08ea1bf76bd108310f64d7f80e0f30b61ddc3d71f6cab7bdce329d2c1fa original: hackage: token-bucket-0.1.0.1 +- completed: + hackage: normaldistribution-1.1.0.3@sha256:2615b784c4112cbf6ffa0e2b55b76790290a9b9dff18a05d8c89aa374b213477,2160 + pantry-tree: + size: 269 + sha256: 856818862d12df8b030fa9cfef2c4ffa604d06f0eb057498db245dfffcd60e3c + original: + hackage: normaldistribution-1.1.0.3 snapshots: - completed: size: 494635 diff --git a/test/Load.hs b/test/Load.hs new file mode 100644 index 000000000..7533b6f13 --- /dev/null +++ b/test/Load.hs @@ -0,0 +1,181 @@ +{-# OPTIONS_GHC -fno-warn-unused-top-binds #-} +{-# OPTIONS_GHC -fno-warn-orphans #-} + +module Load + ( main + ) where + +import "uniworx" Import hiding (Option(..), Normal) + +import System.Console.GetOpt + +import qualified Data.Text as Text + +import qualified Data.Map.Strict as Map + +import Data.Random.Normal +import qualified Control.Monad.Random.Class as Random +import System.Random (RandomGen) + +import System.Exit (exitWith, ExitCode(..)) +import System.IO (hPutStrLn) + +import UnliftIO.Concurrent (threadDelay) + +import System.Clock (getTime, Clock(Monotonic)) +import qualified System.Clock as Clock + + +data Normal k = Normal + { dAvg :: k + , dRelDev :: Centi + } deriving (Eq, Ord, Read, Show, Generic, Typeable) + +sampleN :: (Random.MonadSplit g m, RandomGen g) => (k -> Centi -> k) -> Normal k -> m k +sampleN scale Normal{..} + | dRelDev == 0 = return dAvg + | otherwise = do + gen <- Random.getSplit + let (realToFrac -> r, _) = normal' (1, realToFrac dRelDev :: Double) gen + return $ dAvg `scale` r + +instance PathPiece k => PathPiece (Normal k) where + toPathPiece Normal{dRelDev = MkFixed perc, dAvg} + | perc == 0 = toPathPiece dAvg + | otherwise = toPathPiece dAvg <> ";" <> toPathPiece perc <> "%" + fromPathPiece t + | (avg, relDev') <- Text.breakOn ";" t + , Just relDev <- Text.stripSuffix "%" =<< Text.stripPrefix ";" relDev' + = Normal <$> fromPathPiece avg <*> (MkFixed <$> fromPathPiece relDev) + | otherwise + = Normal <$> fromPathPiece t <*> pure 0 + +scaleDiffTime :: DiffTime -> Centi -> DiffTime +scaleDiffTime (diffTimeToPicoseconds -> ps) s = picosecondsToDiffTime . round $ s * fromIntegral ps + +sampleNDiffTime :: (Random.MonadSplit g m, RandomGen g) => Normal DiffTime -> m DiffTime +sampleNDiffTime = sampleN scaleDiffTime + + +instance PathPiece DiffTime where + toPathPiece = toPathPiece . MkFixed @E12 . diffTimeToPicoseconds + fromPathPiece t = fromPathPiece t <&> \(MkFixed ps :: Pico) -> picosecondsToDiffTime ps + + +data LoadSimulation + = LoadSheetSubmission + deriving (Eq, Ord, Read, Show, Enum, Bounded, Generic, Typeable) + deriving anyclass (Universe, Finite) + +nullaryPathPiece ''LoadSimulation $ camelToPathPiece' 1 + +data LoadOptions = LoadOptions + { loadSimulations :: Map LoadSimulation SimulationOptions + } deriving (Eq, Ord, Show, Generic, Typeable) + +instance Default LoadOptions where + def = LoadOptions + { loadSimulations = Map.empty + } + +data SimulationOptions = SimulationOptions + { simParallel :: Natural + , simDelay, simDuration :: Normal DiffTime + } deriving (Eq, Ord, Show, Generic, Typeable) + +instance Default SimulationOptions where + def = SimulationOptions + { simParallel = 1 + , simDelay = Normal 0 0 + , simDuration = Normal 10 0 + } + + +data SimulationContext = SimulationContext + { loadOptions :: LoadOptions + , simulationOptions :: SimulationOptions + , targetDuration :: DiffTime + , runtime :: forall m. MonadIO m => m DiffTime + } + + +makeLenses_ ''LoadOptions +makeLenses_ ''SimulationOptions +makeLenses_ ''SimulationContext + + +_MapF :: (Finite k, Ord k) => Iso' (Map k v) (k -> Maybe v) +_MapF = iso (flip Map.lookup) (\f -> Map.fromList $ mapMaybe (\k -> (k, ) <$> f k) universeF) + + +argsDescr :: [OptDescr (Endo LoadOptions)] +argsDescr + = [ Option ['n', 'p'] ["number", "parallel"] (ReqArg (\(splitArg -> (cloneIndexedTraversal -> f, arg)) -> Endo . over f $ set _simParallel arg) "NATURAL") "Number of simulations to run in parallel" + , Option ['r'] ["run"] (ReqArg (\(ppArg -> sim) -> Endo $ over (_loadSimulations . at sim) (<|> Just def)) "SIMULATION") "Run the given Simulation" + , Option ['d'] ["duration"] (ReqArg (\(splitArg -> (cloneIndexedTraversal -> f, arg)) -> Endo . over f $ set _simDuration arg) "DURATION") "Try to run each simulation to take up the given duration" + , Option ['w', 's'] ["wait", "delay", "stagger"] (ReqArg (\(splitArg -> (cloneIndexedTraversal -> f, arg)) -> Endo . over f $ set _simDelay arg) "DURATION") "Wait the given time before starting each simulation" + ] + where + splitArg :: PathPiece p => String -> (AnIndexedTraversal' LoadSimulation LoadOptions SimulationOptions, p) + splitArg (Text.pack -> t) + | (ref, arg) <- Text.breakOn ":" t + , let refs = Text.splitOn "," ref + sArg = Text.stripPrefix ":" arg + , Just refs' <- if | is _Just sArg -> mapM fromPathPiece refs + | otherwise -> Just [] + , Just arg' <- fromPathPiece $ fromMaybe ref sArg + = (, arg') $ if + | null refs' -> _loadSimulations . itraversed + | otherwise -> _loadSimulations . _MapF . itraversed . indices (`elem` refs') . iplens (fromMaybe def) (const Just) + | otherwise + = terror $ "Invalid option argument: " <> t + + ppArg :: PathPiece p => String -> p + ppArg (Text.pack -> a) = fromMaybe (terror $ "Invalid option argument: " <> a) $ fromPathPiece a + +main :: IO () +main = do + args <- map unpack <$> getArgs + case over _1 (over _loadSimulations (Map.filter $ (> 0) . simParallel) . (`appEndo` def) . getDual . foldMap Dual) $ getOpt Permute argsDescr args of + (cfg, [], []) | not . Map.null $ loadSimulations cfg + -> imapM_ (\sim simOpts -> runReaderT (runSimulation sim) (cfg & _loadSimulations . at sim .~ Nothing, simOpts)) $ loadSimulations cfg + (_, _, errs) -> do + forM_ errs $ hPutStrLn stderr + hPutStrLn stderr $ usageInfo "uniworxload" argsDescr + exitWith $ ExitFailure 2 + +runSimulation :: LoadSimulation -> ReaderT (LoadOptions, SimulationOptions) IO () +runSimulation sim = do + p <- view $ _2 . _simParallel + replicateConcurrently_ (fromIntegral p) $ do + d <- view $ _2 . _simDelay + d' <- sampleNDiffTime d + + dur <- view $ _2 . _simDuration + tDuration <- sampleNDiffTime dur + + let MkFixed us = realToFrac d' :: Micro + threadDelay $ fromInteger us + + cTime <- liftIO $ getTime Monotonic + let running :: forall m. MonadIO m => m DiffTime + running = do + cTime' <- liftIO $ getTime Monotonic + let diff = MkFixed . Clock.toNanoSecs $ cTime' - cTime :: Nano + MkFixed ps = realToFrac diff :: Pico + return $ picosecondsToDiffTime ps + + withReaderT (\(lO, sO) -> SimulationContext lO sO tDuration running) $ runSimulation' sim + + +delayRemaining :: (MonadReader SimulationContext m, MonadIO m, RealFrac r) => r -> m () +delayRemaining p = do + total <- asks targetDuration + cTime <- join $ asks runtime + let remaining = MkFixed . diffTimeToPicoseconds $ total - cTime :: Pico + MkFixed us = realToFrac $ realToFrac remaining * p :: Micro + threadDelay $ fromInteger us + + +runSimulation' :: LoadSimulation -> ReaderT SimulationContext IO () +runSimulation' = liftIO . print