diff --git a/hyx/ratelimit/__init__.py b/hyx/ratelimit/__init__.py index d301ab9..a587039 100644 --- a/hyx/ratelimit/__init__.py +++ b/hyx/ratelimit/__init__.py @@ -1,8 +1,10 @@ from hyx.ratelimit.api import ratelimiter, tokenbucket +from hyx.ratelimit.buckets import TokenBucket from hyx.ratelimit.managers import TokenBucketLimiter __all__ = ( "ratelimiter", "tokenbucket", "TokenBucketLimiter", + "TokenBucket", ) diff --git a/hyx/ratelimit/buckets.py b/hyx/ratelimit/buckets.py new file mode 100644 index 0000000..e108bc7 --- /dev/null +++ b/hyx/ratelimit/buckets.py @@ -0,0 +1,84 @@ +import asyncio +from typing import Optional + +from hyx.ratelimit.exceptions import EmptyBucket + + +class TokenBucket: + """ + Token Bucket Logic + Replenish tokens as time passes on. If tokens are available, executions can be allowed. + Otherwise, it's going to be rejected with an EmptyBucket error + """ + + __slots__ = ( + "_max_executions", + "_per_time_secs", + "_bucket_size", + "_loop", + "_token_per_secs", + "_tokens", + "_next_replenish_at", + ) + + def __init__(self, max_executions: float, per_time_secs: float, bucket_size: Optional[float] = None) -> None: + self._max_executions = max_executions + self._per_time_secs = per_time_secs + + self._bucket_size = bucket_size if bucket_size else max_executions + + self._loop = asyncio.get_running_loop() + self._token_per_secs = self._per_time_secs / self._max_executions + + self._tokens = self._bucket_size + self._next_replenish_at = self._loop.time() + self._token_per_secs + + @property + def tokens(self) -> float: + self._replenish() + return self._tokens + + @property + def empty(self) -> bool: + self._replenish() + return self._tokens <= 0 + + async def take(self) -> None: + if not self.empty: + self._tokens -= 1 + return + + now = self._loop.time() + + next_replenish = self._next_replenish_at + until_next_replenish = next_replenish - now + + if until_next_replenish > 0: + raise EmptyBucket + + tokens_to_add = min(self._bucket_size, 1 + abs(until_next_replenish / self._token_per_secs)) + + self._next_replenish_at = max( + next_replenish + tokens_to_add * self._token_per_secs, + now + self._token_per_secs, + ) + + self._tokens = tokens_to_add - 1 + return + + def _replenish(self) -> None: + now = self._loop.time() + + next_replenish = self._next_replenish_at + until_next_replenish = next_replenish - now + + if until_next_replenish > 0: + return + + tokens_to_add = min(self._bucket_size, 1 + abs(until_next_replenish / self._token_per_secs)) + self._next_replenish_at = max( + next_replenish + tokens_to_add * self._token_per_secs, + now + self._token_per_secs, + ) + self._tokens = tokens_to_add + return diff --git a/hyx/ratelimit/exceptions.py b/hyx/ratelimit/exceptions.py index 9a5e66e..4d65486 100644 --- a/hyx/ratelimit/exceptions.py +++ b/hyx/ratelimit/exceptions.py @@ -5,3 +5,9 @@ class RateLimitExceeded(HyxError): """ Occurs when requester have exceeded the rate limit """ + + +class EmptyBucket(HyxError): + """ + Occurs when requester have exceeded the rate limit + """ diff --git a/hyx/ratelimit/managers.py b/hyx/ratelimit/managers.py index b6294da..088851f 100644 --- a/hyx/ratelimit/managers.py +++ b/hyx/ratelimit/managers.py @@ -1,7 +1,7 @@ -import asyncio from typing import Optional -from hyx.ratelimit.exceptions import RateLimitExceeded +from hyx.ratelimit.buckets import TokenBucket +from hyx.ratelimit.exceptions import EmptyBucket, RateLimitExceeded class RateLimiter: @@ -16,59 +16,20 @@ class TokenBucketLimiter(RateLimiter): Otherwise, it's going to be rejected with RateLimitExceeded """ - __slots__ = ( - "_max_executions", - "_per_time_secs", - "_bucket_size", - "_loop", - "_token_per_secs", - "_tokens", - "_next_replenish_at", - ) + __slots__ = ("_token_bucket",) def __init__(self, max_executions: float, per_time_secs: float, bucket_size: Optional[float] = None) -> None: - self._max_executions = max_executions - self._per_time_secs = per_time_secs - - self._bucket_size = bucket_size if bucket_size else max_executions - - self._loop = asyncio.get_running_loop() - self._token_per_secs = self._per_time_secs / self._max_executions - - self._tokens = self._bucket_size - self._next_replenish_at = self._loop.time() + self._token_per_secs + self._token_bucket = TokenBucket(max_executions, per_time_secs, bucket_size) @property - def tokens(self) -> float: - return self._tokens - - @property - def empty(self) -> bool: - return self._tokens <= 0 + def bucket(self) -> TokenBucket: + return self._token_bucket async def acquire(self) -> None: - if not self.empty: - self._tokens -= 1 - return - - now = self._loop.time() - - next_replenish = self._next_replenish_at - until_next_replenish = next_replenish - now - - if until_next_replenish > 0: - raise RateLimitExceeded - - tokens_to_add = min(self._bucket_size, 1 + abs(until_next_replenish / self._token_per_secs)) - - self._next_replenish_at = max( - next_replenish + tokens_to_add * self._token_per_secs, - now + self._token_per_secs, - ) - - # account for the current call - self._tokens = tokens_to_add - 1 - return + try: + await self._token_bucket.take() + except EmptyBucket as e: + raise RateLimitExceeded from e class LeakyTokenBucketLimiter(RateLimiter): diff --git a/hyx/retry/api.py b/hyx/retry/api.py index 019c881..efd98b9 100644 --- a/hyx/retry/api.py +++ b/hyx/retry/api.py @@ -2,7 +2,7 @@ from typing import Any, Callable, Optional, Sequence, cast from hyx.events import EventDispatcher, EventManager, get_default_name -from hyx.ratelimit.managers import TokenBucketLimiter +from hyx.ratelimit.buckets import TokenBucket from hyx.retry.events import _RETRY_LISTENERS, RetryListener from hyx.retry.manager import RetryManager from hyx.retry.typing import AttemptsT, BackoffsT, BucketRetryT @@ -77,7 +77,7 @@ def bucket_retry( """ def _decorator(func: FuncT) -> FuncT: - limiter = TokenBucketLimiter(attempts, per_time_secs, bucket_size) if attempts and per_time_secs else None + limiter = TokenBucket(attempts, per_time_secs, bucket_size) if attempts and per_time_secs else None event_dispatcher = EventDispatcher[RetryManager, RetryListener]( listeners, _RETRY_LISTENERS, diff --git a/hyx/retry/manager.py b/hyx/retry/manager.py index 03a726c..111f032 100644 --- a/hyx/retry/manager.py +++ b/hyx/retry/manager.py @@ -1,7 +1,7 @@ import asyncio from typing import Any, Optional -from hyx.ratelimit.managers import TokenBucketLimiter +from hyx.ratelimit.buckets import TokenBucket from hyx.retry.backoffs import create_backoff from hyx.retry.counters import create_counter from hyx.retry.events import RetryListener @@ -28,7 +28,7 @@ def __init__( attempts: AttemptsT, backoff: BackoffsT, event_dispatcher: RetryListener, - limiter: Optional[TokenBucketLimiter] = None, + limiter: Optional[TokenBucket] = None, ) -> None: self._name = name self._exceptions = exceptions @@ -49,7 +49,7 @@ async def __call__(self, func: FuncT) -> Any: while bool(counter): try: if self._limiter is not None: - await self._limiter.acquire() + await self._limiter.take() result = await func() diff --git a/tests/test_ratelimiter/test_buckets.py b/tests/test_ratelimiter/test_buckets.py new file mode 100644 index 0000000..489efc4 --- /dev/null +++ b/tests/test_ratelimiter/test_buckets.py @@ -0,0 +1,35 @@ +import asyncio + +import pytest + +from hyx.ratelimit.buckets import TokenBucket +from hyx.ratelimit.exceptions import EmptyBucket + + +async def test__token_bucket_success() -> None: + bucket = TokenBucket(3, 1, 3) + + for i in range(3): + assert bucket.tokens == (3 - i) + await bucket.take() + assert bucket.empty is True + + +async def test__token_bucket_limit_exceeded() -> None: + bucket = TokenBucket(3, 1, 3) + + with pytest.raises(EmptyBucket): + for _ in range(4): + await bucket.take() + + +async def test__token_bucket__fully_replenish_after_time_period() -> None: + bucket = TokenBucket(3, 1, 3) + + for _ in range(3): + await bucket.take() + + await asyncio.sleep(3) + + assert bucket.tokens == 3 + assert bucket.empty is False