fradrive/src/Utils/PersistentTokenBucket.hs
2022-10-12 09:35:16 +02:00

119 lines
4.6 KiB
Haskell

-- SPDX-FileCopyrightText: 2022 Gregor Kleen <gregor.kleen@ifi.lmu.de>
--
-- SPDX-License-Identifier: AGPL-3.0-or-later
module Utils.PersistentTokenBucket
( TokenBucketSettings(..)
, persistentTokenBucketTryAlloc', persistentTokenBucketTakeC'
, persistentTokenBucketTryAlloc, persistentTokenBucketTakeC
, persistentTokenBucketRateLimit', persistentTokenBucketRateLimit
) where
import Import.NoFoundation
import qualified Data.Conduit.Combinators as C
import Control.Concurrent.STM.Delay
-- | Parameters describing a single persistent token bucket.
data TokenBucketSettings = TokenBucketSettings
  { tbsIdent :: TokenBucketIdent
    -- ^ Identifier of the bucket; used as the database key (@TokenBucketKey@).
  , tbsDepth :: Word64
    -- ^ Capacity of the bucket; used as the upper bound (via 'min') when the
    -- current bucket value is computed.
  , tbsInvRate :: NominalDiffTime
    -- ^ Inverse refill rate: the amount of time that has to pass for one
    -- token to be regained (also the per-token delay of
    -- 'persistentTokenBucketRateLimit').
  , tbsInitialValue :: Int64
    -- ^ Value a bucket starts with when it is created on first use.
  }
-- | Variant of 'persistentTokenBucketTryAlloc' that reads the bucket's depth,
-- inverse rate and initial value for the given identifier from the
-- application settings ('_appPersistentTokenBuckets') instead of taking
-- explicit 'TokenBucketSettings'.
persistentTokenBucketTryAlloc' :: (MonadHandler m, HasAppSettings (HandlerSite m), Integral a)
  => TokenBucketIdent
  -> a
  -> SqlPersistT m Bool
persistentTokenBucketTryAlloc' tbsIdent tokens =
  getsYesod (views _appPersistentTokenBuckets ($ tbsIdent))
    >>= \conf -> persistentTokenBucketTryAlloc (toSettings conf) tokens
  where
    -- Combine the configured parameters with the requested identifier.
    toSettings TokenBucketConf{..} = TokenBucketSettings
      { tbsIdent
      , tbsDepth        = tokenBucketDepth
      , tbsInvRate      = tokenBucketInvRate
      , tbsInitialValue = tokenBucketInitialValue
      }
-- | Try to take @tokens@ tokens from the persistent token bucket described by
-- the given settings. A bucket that does not exist yet is created with
-- 'tbsInitialValue' tokens.
--
-- Since the last recorded access the bucket regains one token per
-- 'tbsInvRate' of elapsed time. After that refill, its value is capped at
-- 'tbsDepth'. The allocation succeeds whenever this refilled value is
-- non-negative. The requested tokens are then subtracted, which may leave the
-- stored value negative. On failure the stored bucket is not modified.
--
-- Returns whether the allocation succeeded.
persistentTokenBucketTryAlloc :: (MonadIO m, Integral a) => TokenBucketSettings -> a -> SqlPersistT m Bool
persistentTokenBucketTryAlloc TokenBucketSettings{..} (fromIntegral -> tokens) = do
  now <- liftIO getCurrentTime
  TokenBucket{..} <- do
    existingBucket <- get $ TokenBucketKey tbsIdent
    case existingBucket of
      Just bkt -> return bkt
      Nothing -> do
        -- NOTE(review): get-then-insert is not atomic; two concurrent
        -- transactions creating the same bucket may conflict — confirm that
        -- callers' transaction isolation covers this.
        let bkt = TokenBucket
              { tokenBucketIdent = tbsIdent
              , tokenBucketLastValue = tbsInitialValue
              , tokenBucketLastAccess = now
              }
        insert_ bkt
        return bkt
  let deltaT = now `diffUTCTime` tokenBucketLastAccess
      -- Split the elapsed time into whole tokens, rounded towards negative
      -- infinity, and the leftover time (in [0, tbsInvRate) for a positive
      -- tbsInvRate). 'properFraction' truncates towards zero, so a negative
      -- fractional part has to be borrowed from the integral part. Testing
      -- @f@ (not @n@) also covers elapsed fractions in (-1, 0).
      (tokenIncrease, deltaT')
        | f < 0     = (pred n, (1 + f) * tbsInvRate)
        | otherwise = (n, f * tbsInvRate)
        where (n, f) = properFraction $ deltaT / tbsInvRate
      -- Refill first, then cap at the depth. A bare backtick `min` is
      -- infixl 9 and would bind tighter than '+', capping before the refill.
      currentValue = min (fromIntegral tbsDepth) (tokenBucketLastValue + tokenIncrease)
  if | currentValue < 0 -> return False
     | otherwise -> do
         -- Keep the not-yet-converted leftover time by backdating the access.
         update (TokenBucketKey tbsIdent) [ TokenBucketLastValue =. currentValue - tokens, TokenBucketLastAccess =. addUTCTime (- deltaT') now ]
         return True
-- | Variant of 'persistentTokenBucketTakeC' whose bucket parameters (depth,
-- inverse rate, initial value) are looked up for the given identifier in the
-- application settings ('_appPersistentTokenBuckets').
persistentTokenBucketTakeC' :: forall i m a.
  (MonadHandler m, HasAppSettings (HandlerSite m), Integral a)
  => TokenBucketIdent
  -> (i -> a)
  -> ConduitT i i (ReaderT SqlBackend m) ()
persistentTokenBucketTakeC' tbsIdent cTokens = do
  TokenBucketConf{..} <- getsYesod . views _appPersistentTokenBuckets $ \confFor -> confFor tbsIdent
  let settings = TokenBucketSettings
        { tbsIdent
        , tbsDepth = tokenBucketDepth
        , tbsInvRate = tokenBucketInvRate
        , tbsInitialValue = tokenBucketInitialValue
        }
  persistentTokenBucketTakeC settings cTokens
-- | Pass elements downstream for as long as the token bucket allows it.
--
-- Before each element @x@ is yielded, @cTokens x@ tokens are allocated with
-- 'persistentTokenBucketTryAlloc'. The stream stops at the first element for
-- which the allocation fails.
--
-- NOTE(review): 'C.mapAccumWhileM' does not appear to push the rejected
-- element back as a leftover, so it is consumed without being yielded —
-- confirm this is what callers expect.
persistentTokenBucketTakeC :: forall i m a.
  (MonadIO m, Integral a)
  => TokenBucketSettings
  -> (i -> a)
  -> ConduitT i i (ReaderT SqlBackend m) ()
persistentTokenBucketTakeC tbs cTokens = C.mapAccumWhileM step ()
  where
    step :: i
         -> ()
         -> SqlPersistT m (Either () ((), i))
    step x () = do
      allowed <- persistentTokenBucketTryAlloc tbs (cTokens x)
      return $ if allowed then Right ((), x) else Left ()
-- | Throttle a stream by time alone.
--
-- Each element @x@ is yielded immediately. Afterwards the conduit sleeps for
-- @cTokens x * tbsInvRate@, at microsecond resolution, before awaiting the
-- next element. Only 'tbsInvRate' is consulted; no persistent bucket state is
-- read or written.
persistentTokenBucketRateLimit :: forall i m a.
  ( MonadIO m, Integral a )
  => TokenBucketSettings
  -> (i -> a)
  -> ConduitT i i m ()
persistentTokenBucketRateLimit TokenBucketSettings{tbsInvRate} cTokens = awaitForever $ \x -> do
  yield x
  liftIO . waitFor $ cTokens x
  where
    -- Block for the time corresponding to the given number of tokens.
    waitFor n = do
      let MkFixed micros = realToFrac (fromIntegral n * tbsInvRate) :: Micro
      delay <- newDelay (fromIntegral micros)
      atomically $ waitDelay delay
-- | Variant of 'persistentTokenBucketRateLimit' that reads the bucket's
-- parameters for the given identifier from the application settings
-- ('_appPersistentTokenBuckets'). Of these, only the inverse rate affects the
-- delay.
persistentTokenBucketRateLimit' :: forall i m a.
  (MonadHandler m, HasAppSettings (HandlerSite m), Integral a)
  => TokenBucketIdent
  -> (i -> a)
  -> ConduitT i i m ()
persistentTokenBucketRateLimit' tbsIdent cTokens = do
  conf <- getsYesod $ views _appPersistentTokenBuckets ($ tbsIdent)
  persistentTokenBucketRateLimit (settingsFrom conf) cTokens
  where
    -- Pair the configured parameters with the requested identifier.
    settingsFrom TokenBucketConf{..} = TokenBucketSettings
      { tbsIdent
      , tbsDepth = tokenBucketDepth
      , tbsInvRate = tokenBucketInvRate
      , tbsInitialValue = tokenBucketInitialValue
      }