-- SPDX-FileCopyrightText: 2022 Gregor Kleen
--
-- SPDX-License-Identifier: AGPL-3.0-or-later

-- | Token buckets whose state is persisted in the database, plus conduit
-- combinators built on top of them.
module Utils.PersistentTokenBucket
  ( TokenBucketSettings(..)
  , persistentTokenBucketTryAlloc', persistentTokenBucketTakeC'
  , persistentTokenBucketTryAlloc, persistentTokenBucketTakeC
  , persistentTokenBucketRateLimit', persistentTokenBucketRateLimit
  ) where

import Import.NoFoundation

import qualified Data.Conduit.Combinators as C

import Control.Concurrent.STM.Delay


-- | Parameters of a single token bucket.
data TokenBucketSettings = TokenBucketSettings
  { tbsIdent        :: TokenBucketIdent
    -- ^ Identifies the bucket; used as its database key ('TokenBucketKey').
  , tbsDepth        :: Word64
    -- ^ Maximum number of tokens the bucket is refilled to.
  , tbsInvRate      :: NominalDiffTime
    -- ^ Time it takes to accrue one token (the inverse of the refill rate).
  , tbsInitialValue :: Int64
    -- ^ Token balance of a bucket that does not exist in the database yet.
  }

-- | Look up the configured settings of the bucket identified by the given
-- 'TokenBucketIdent' in the application settings.
persistentTokenBucketSettings :: (MonadHandler m, HasAppSettings (HandlerSite m))
                              => TokenBucketIdent
                              -> m TokenBucketSettings
persistentTokenBucketSettings tbsIdent = do
  TokenBucketConf{..} <- getsYesod $ views _appPersistentTokenBuckets ($ tbsIdent)
  return TokenBucketSettings
    { tbsIdent
    , tbsDepth        = tokenBucketDepth
    , tbsInvRate      = tokenBucketInvRate
    , tbsInitialValue = tokenBucketInitialValue
    }

-- | Like 'persistentTokenBucketTryAlloc', but reads the bucket's settings
-- from the application settings.
persistentTokenBucketTryAlloc' :: (MonadHandler m, HasAppSettings (HandlerSite m), Integral a)
                               => TokenBucketIdent
                               -> a
                               -> SqlPersistT m Bool
persistentTokenBucketTryAlloc' tbsIdent tokens = do
  tbs <- persistentTokenBucketSettings tbsIdent
  persistentTokenBucketTryAlloc tbs tokens

-- | Try to take @tokens@ tokens from the bucket described by the settings.
--
-- If no row exists for 'tbsIdent' yet, one is inserted holding
-- 'tbsInitialValue' tokens. The balance is then refilled by one token per
-- 'tbsInvRate' of time elapsed since the last successful allocation, and
-- capped at 'tbsDepth'.
--
-- Returns 'False', without recording the allocation, if the refilled balance
-- is negative. Otherwise the requested tokens are subtracted and 'True' is
-- returned. The subtraction happens even if it drives the balance below zero,
-- so the bucket can go into debt.
persistentTokenBucketTryAlloc :: (MonadIO m, Integral a)
                              => TokenBucketSettings
                              -> a
                              -> SqlPersistT m Bool
persistentTokenBucketTryAlloc TokenBucketSettings{..} (fromIntegral -> tokens) = do
  now <- liftIO getCurrentTime
  TokenBucket{..} <- do
    existingBucket <- get $ TokenBucketKey tbsIdent
    case existingBucket of
      Just bkt -> return bkt
      Nothing -> do
        let bkt = TokenBucket
              { tokenBucketIdent      = tbsIdent
              , tokenBucketLastValue  = tbsInitialValue
              , tokenBucketLastAccess = now
              }
        insert_ bkt
        return bkt

  let deltaT = now `diffUTCTime` tokenBucketLastAccess
      -- Elapsed time measured in tokens; negative if the clock went backwards.
      elapsedTokens = deltaT / tbsInvRate
      -- 'floor' rather than 'truncate', so the leftover below is never
      -- negative (for a positive 'tbsInvRate').
      tokenIncrease :: Integer
      tokenIncrease = floor elapsedTokens
      -- Elapsed time not yet converted into a whole token. It is carried
      -- over by backdating the stored access time.
      deltaT' = (elapsedTokens - fromInteger tokenIncrease) * tbsInvRate
      -- Refill, but never beyond the bucket's depth. Use prefix 'min':
      -- backticked `min` binds tighter than '+', which would apply the cap
      -- before the refill. The sum is computed in 'Integer' so that a long
      -- idle period cannot overflow 'Int64'.
      currentValue :: Int64
      currentValue = fromInteger $
        min (toInteger tbsDepth) (toInteger tokenBucketLastValue + tokenIncrease)

  if | currentValue < 0 -> return False
     | otherwise -> do
         update (TokenBucketKey tbsIdent)
           [ TokenBucketLastValue  =. currentValue - tokens
           , TokenBucketLastAccess =. addUTCTime (- deltaT') now
           ]
         return True

-- | Like 'persistentTokenBucketTakeC', but reads the bucket's settings from
-- the application settings.
persistentTokenBucketTakeC' :: forall i m a.
                               (MonadHandler m, HasAppSettings (HandlerSite m), Integral a)
                            => TokenBucketIdent
                            -> (i -> a)
                            -> ConduitT i i (ReaderT SqlBackend m) ()
persistentTokenBucketTakeC' tbsIdent cTokens = do
  tbs <- persistentTokenBucketSettings tbsIdent
  persistentTokenBucketTakeC tbs cTokens

-- | Pass elements through for as long as allocating @cTokens x@ tokens for
-- each element @x@ succeeds (see 'persistentTokenBucketTryAlloc'). Stop at the
-- first failed allocation.
--
-- NOTE(review): with 'C.mapAccumWhileM', the element whose allocation fails
-- appears to be consumed without being yielded or pushed back as a leftover.
-- Confirm this is intended.
persistentTokenBucketTakeC :: forall i m a.
                              (MonadIO m, Integral a)
                           => TokenBucketSettings
                           -> (i -> a)
                           -> ConduitT i i (ReaderT SqlBackend m) ()
persistentTokenBucketTakeC tbs cTokens = C.mapAccumWhileM tbAccum ()
  where
    tbAccum :: i -> () -> SqlPersistT m (Either () ((), i))
    tbAccum x () = bool (Left ()) (Right ((), x)) <$> persistentTokenBucketTryAlloc tbs (cTokens x)

-- | Throttle a stream: after yielding each element @x@, sleep for
-- @cTokens x * tbsInvRate@.
--
-- Only 'tbsInvRate' is used; the database is not consulted.
persistentTokenBucketRateLimit :: forall i m a.
                                  (MonadIO m, Integral a)
                               => TokenBucketSettings
                               -> (i -> a)
                               -> ConduitT i i m ()
persistentTokenBucketRateLimit TokenBucketSettings{tbsInvRate} cTokens = awaitForever $ \x@(cTokens -> s) -> do
  yield x
  -- Convert the delay to whole microseconds, the unit 'newDelay' expects.
  let MkFixed (fromIntegral -> dTime) = (realToFrac $ fromIntegral s * tbsInvRate :: Micro)
  liftIO $ atomically . waitDelay =<< newDelay dTime

-- | Like 'persistentTokenBucketRateLimit', but reads the bucket's settings
-- from the application settings.
persistentTokenBucketRateLimit' :: forall i m a.
                                   (MonadHandler m, HasAppSettings (HandlerSite m), Integral a)
                                => TokenBucketIdent
                                -> (i -> a)
                                -> ConduitT i i m ()
persistentTokenBucketRateLimit' tbsIdent cTokens = do
  tbs <- persistentTokenBucketSettings tbsIdent
  persistentTokenBucketRateLimit tbs cTokens