119 lines
4.6 KiB
Haskell
119 lines
4.6 KiB
Haskell
-- SPDX-FileCopyrightText: 2022 Gregor Kleen <gregor.kleen@ifi.lmu.de>
|
|
--
|
|
-- SPDX-License-Identifier: AGPL-3.0-or-later
|
|
|
|
module Utils.PersistentTokenBucket
|
|
( TokenBucketSettings(..)
|
|
, persistentTokenBucketTryAlloc', persistentTokenBucketTakeC'
|
|
, persistentTokenBucketTryAlloc, persistentTokenBucketTakeC
|
|
, persistentTokenBucketRateLimit', persistentTokenBucketRateLimit
|
|
) where
|
|
|
|
import Import.NoFoundation
|
|
|
|
import qualified Data.Conduit.Combinators as C
|
|
|
|
import Control.Concurrent.STM.Delay
|
|
|
|
|
|
-- | Parameters describing one persistent token bucket.
--
-- Tokens accrue at a rate of one per 'tbsInvRate'; 'tbsDepth' is the
-- capacity the refilled level is capped against.
data TokenBucketSettings = TokenBucketSettings
  { tbsIdent        :: TokenBucketIdent
    -- ^ Which bucket this is; used as the key of the bucket's database row
  , tbsDepth        :: Word64
    -- ^ Capacity: the maximum number of tokens the bucket is meant to hold
  , tbsInvRate      :: NominalDiffTime
    -- ^ Inverse refill rate: the time it takes for one token to accrue
  , tbsInitialValue :: Int64
    -- ^ Token count a bucket starts with when its row is first created
  }
|
|
|
|
|
|
-- | Variant of 'persistentTokenBucketTryAlloc' that reads the bucket's depth,
-- inverse rate and initial value from the application settings
-- ('_appPersistentTokenBuckets') for the given identifier.
persistentTokenBucketTryAlloc'
  :: (MonadHandler m, HasAppSettings (HandlerSite m), Integral a)
  => TokenBucketIdent   -- ^ bucket to draw from
  -> a                  -- ^ number of tokens requested
  -> SqlPersistT m Bool
persistentTokenBucketTryAlloc' ident tokens = do
  TokenBucketConf{ tokenBucketDepth, tokenBucketInvRate, tokenBucketInitialValue } <-
    getsYesod $ views _appPersistentTokenBuckets ($ ident)
  let settings = TokenBucketSettings
        { tbsIdent        = ident
        , tbsDepth        = tokenBucketDepth
        , tbsInvRate      = tokenBucketInvRate
        , tbsInitialValue = tokenBucketInitialValue
        }
  persistentTokenBucketTryAlloc settings tokens
|
|
|
|
|
|
-- | Try to take @tokens@ tokens from the persistent token bucket described by
-- the given settings.
--
-- On first use the bucket row is created holding 'tbsInitialValue' tokens.
-- Before allocating, the bucket is refilled with one token per whole
-- 'tbsInvRate' elapsed since its last recorded access, and the refilled level
-- is capped at 'tbsDepth'.
--
-- Allocation succeeds whenever the refilled level is non-negative, even if it
-- is smaller than @tokens@. The bucket is then left in debt (negative level),
-- and later requests fail until enough time has passed to repay it.
--
-- Returns 'True' and persists the new level and access time on success.
-- Returns 'False' and leaves the bucket row untouched otherwise.
--
-- NOTE(review): the get/insert/update sequence is not guarded against
-- concurrent callers on the same bucket (two first-time callers may both try
-- to insert); presumably transaction isolation is relied upon — confirm.
persistentTokenBucketTryAlloc :: (MonadIO m, Integral a) => TokenBucketSettings -> a -> SqlPersistT m Bool
persistentTokenBucketTryAlloc TokenBucketSettings{..} (fromIntegral -> tokens) = do
  now <- liftIO getCurrentTime

  -- Load the bucket, creating it with the configured initial value on first use.
  TokenBucket{..} <- do
    existingBucket <- get $ TokenBucketKey tbsIdent
    case existingBucket of
      Just bkt -> return bkt
      Nothing -> do
        let bkt = TokenBucket
              { tokenBucketIdent = tbsIdent
              , tokenBucketLastValue = tbsInitialValue
              , tokenBucketLastAccess = now
              }
        insert_ bkt
        return bkt

  let deltaT = now `diffUTCTime` tokenBucketLastAccess
      -- Whole tokens accrued since the last access, rounded towards negative
      -- infinity (a negative 'deltaT', e.g. from clock skew, removes tokens).
      tokenIncrease :: Int64
      tokenIncrease = floor $ deltaT / tbsInvRate
      -- Elapsed time that has not yet produced a whole token. For a positive
      -- 'tbsInvRate' it lies in [0, tbsInvRate). It is carried over by
      -- back-dating the stored access time, so partial progress is kept.
      deltaT' = deltaT - fromIntegral tokenIncrease * tbsInvRate
      -- Refill first, then cap at the bucket depth. The parentheses are
      -- essential: backticked `min` (infixl 9) binds tighter than (+).
      currentValue = fromIntegral tbsDepth `min` (tokenBucketLastValue + tokenIncrease)

  if | currentValue < 0 -> return False
     | otherwise -> do
         update (TokenBucketKey tbsIdent)
           [ TokenBucketLastValue  =. currentValue - tokens
           , TokenBucketLastAccess =. addUTCTime (- deltaT') now
           ]
         return True
|
|
|
|
|
|
-- | Variant of 'persistentTokenBucketTakeC' that reads the bucket's depth,
-- inverse rate and initial value from the application settings
-- ('_appPersistentTokenBuckets') for the given identifier.
persistentTokenBucketTakeC'
  :: forall i m a.
     (MonadHandler m, HasAppSettings (HandlerSite m), Integral a)
  => TokenBucketIdent   -- ^ bucket to draw from
  -> (i -> a)           -- ^ token cost of each element
  -> ConduitT i i (ReaderT SqlBackend m) ()
persistentTokenBucketTakeC' ident cTokens = do
  TokenBucketConf{ tokenBucketDepth, tokenBucketInvRate, tokenBucketInitialValue } <-
    getsYesod $ views _appPersistentTokenBuckets ($ ident)
  let settings = TokenBucketSettings
        { tbsIdent        = ident
        , tbsDepth        = tokenBucketDepth
        , tbsInvRate      = tokenBucketInvRate
        , tbsInitialValue = tokenBucketInitialValue
        }
  persistentTokenBucketTakeC settings cTokens
|
|
|
|
-- | Pass elements downstream for as long as the token bucket described by the
-- settings can pay for them.
--
-- Each element @x@ costs @cTokens x@ tokens, requested via
-- 'persistentTokenBucketTryAlloc'. The conduit stops at the first element
-- whose allocation fails. That element is only peeked at, never consumed, so
-- it stays in the stream instead of being silently dropped. This matches the
-- leftover convention of 'Data.Conduit.Combinators.takeWhile'.
persistentTokenBucketTakeC :: forall i m a.
                              (MonadIO m, Integral a)
                           => TokenBucketSettings
                           -> (i -> a)
                           -> ConduitT i i (ReaderT SqlBackend m) ()
persistentTokenBucketTakeC tbs cTokens = loop
  where
    loop :: ConduitT i i (ReaderT SqlBackend m) ()
    loop = C.peek >>= maybe (return ()) admit

    -- Pay for the peeked element; only once paid is it consumed and forwarded.
    admit :: i -> ConduitT i i (ReaderT SqlBackend m) ()
    admit x = do
      allocated <- lift . persistentTokenBucketTryAlloc tbs $ cTokens x
      when allocated $ do
        C.drop 1 -- consume the element we just peeked at
        yield x
        loop
|
|
|
|
|
|
-- | Throttle a stream. After passing each element @x@ downstream, sleep for
-- @cTokens x * tbsInvRate@, converted to whole microseconds.
--
-- Only 'tbsInvRate' is used from the settings. The persisted bucket state is
-- neither read nor written, so this conduit performs no database access.
persistentTokenBucketRateLimit
  :: forall i m a.
     ( MonadIO m, Integral a )
  => TokenBucketSettings
  -> (i -> a)
  -> ConduitT i i m ()
persistentTokenBucketRateLimit TokenBucketSettings{ tbsInvRate } cTokens =
    awaitForever $ \x -> yield x >> pauseFor (cTokens x)
  where
    -- Block for as long as @n@ tokens take to accrue.
    pauseFor :: a -> ConduitT i i m ()
    pauseFor n = liftIO $ do
      let MkFixed micros = realToFrac (fromIntegral n * tbsInvRate) :: Micro
      newDelay (fromIntegral micros) >>= atomically . waitDelay
|
|
|
|
-- | Variant of 'persistentTokenBucketRateLimit' that reads the bucket's
-- parameters from the application settings ('_appPersistentTokenBuckets')
-- for the given identifier.
persistentTokenBucketRateLimit'
  :: forall i m a.
     (MonadHandler m, HasAppSettings (HandlerSite m), Integral a)
  => TokenBucketIdent   -- ^ bucket whose configured rate is used
  -> (i -> a)           -- ^ token cost of each element
  -> ConduitT i i m ()
persistentTokenBucketRateLimit' ident cTokens = do
  TokenBucketConf{ tokenBucketDepth, tokenBucketInvRate, tokenBucketInitialValue } <-
    getsYesod $ views _appPersistentTokenBuckets ($ ident)
  let settings = TokenBucketSettings
        { tbsIdent        = ident
        , tbsDepth        = tokenBucketDepth
        , tbsInvRate      = tokenBucketInvRate
        , tbsInitialValue = tokenBucketInitialValue
        }
  persistentTokenBucketRateLimit settings cTokens
|