Skip to content

Commit

Permalink
#90 Exposing TokenBucket as a standalone component (#100)
Browse files Browse the repository at this point in the history
* Exposing TokenBucket as a standalone component

* Renamed tokenbucket.py to buckets.py and changed the retry component to use the newly exposed TokenBucket component

* Added in new method replenish to ensure token consistency and added in more token bucket tests

* Fixed styling of test_tokenbucket

* Changed variable/function names and reformatted code for better structure

---------

Co-authored-by: Kevin Yang <[email protected]>
  • Loading branch information
kevkevy3000 and Kevin Yang authored Dec 9, 2023
1 parent 513caa3 commit c104a40
Show file tree
Hide file tree
Showing 7 changed files with 142 additions and 54 deletions.
2 changes: 2 additions & 0 deletions hyx/ratelimit/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
from hyx.ratelimit.api import ratelimiter, tokenbucket
from hyx.ratelimit.buckets import TokenBucket
from hyx.ratelimit.managers import TokenBucketLimiter

__all__ = (
"ratelimiter",
"tokenbucket",
"TokenBucketLimiter",
"TokenBucket",
)
84 changes: 84 additions & 0 deletions hyx/ratelimit/buckets.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
import asyncio
from typing import Optional

from hyx.ratelimit.exceptions import EmptyBucket


class TokenBucket:
"""
Token Bucket Logic
Replenish tokens as time passes on. If tokens are available, executions can be allowed.
Otherwise, it's going to be rejected with an EmptyBucket error
"""

__slots__ = (
"_max_executions",
"_per_time_secs",
"_bucket_size",
"_loop",
"_token_per_secs",
"_tokens",
"_next_replenish_at",
)

def __init__(self, max_executions: float, per_time_secs: float, bucket_size: Optional[float] = None) -> None:
self._max_executions = max_executions
self._per_time_secs = per_time_secs

self._bucket_size = bucket_size if bucket_size else max_executions

self._loop = asyncio.get_running_loop()
self._token_per_secs = self._per_time_secs / self._max_executions

self._tokens = self._bucket_size
self._next_replenish_at = self._loop.time() + self._token_per_secs

@property
def tokens(self) -> float:
self._replenish()
return self._tokens

@property
def empty(self) -> bool:
self._replenish()
return self._tokens <= 0

async def take(self) -> None:
if not self.empty:
self._tokens -= 1
return

now = self._loop.time()

next_replenish = self._next_replenish_at
until_next_replenish = next_replenish - now

if until_next_replenish > 0:
raise EmptyBucket

tokens_to_add = min(self._bucket_size, 1 + abs(until_next_replenish / self._token_per_secs))

self._next_replenish_at = max(
next_replenish + tokens_to_add * self._token_per_secs,
now + self._token_per_secs,
)

self._tokens = tokens_to_add - 1
return

def _replenish(self) -> None:
now = self._loop.time()

next_replenish = self._next_replenish_at
until_next_replenish = next_replenish - now

if until_next_replenish > 0:
return

tokens_to_add = min(self._bucket_size, 1 + abs(until_next_replenish / self._token_per_secs))
self._next_replenish_at = max(
next_replenish + tokens_to_add * self._token_per_secs,
now + self._token_per_secs,
)
self._tokens = tokens_to_add
return
6 changes: 6 additions & 0 deletions hyx/ratelimit/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,9 @@ class RateLimitExceeded(HyxError):
"""
Occurs when requester have exceeded the rate limit
"""


class EmptyBucket(HyxError):
"""
Occurs when requester have exceeded the rate limit
"""
59 changes: 10 additions & 49 deletions hyx/ratelimit/managers.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import asyncio
from typing import Optional

from hyx.ratelimit.exceptions import RateLimitExceeded
from hyx.ratelimit.buckets import TokenBucket
from hyx.ratelimit.exceptions import EmptyBucket, RateLimitExceeded


class RateLimiter:
Expand All @@ -16,59 +16,20 @@ class TokenBucketLimiter(RateLimiter):
Otherwise, it's going to be rejected with RateLimitExceeded
"""

__slots__ = (
"_max_executions",
"_per_time_secs",
"_bucket_size",
"_loop",
"_token_per_secs",
"_tokens",
"_next_replenish_at",
)
__slots__ = ("_token_bucket",)

def __init__(self, max_executions: float, per_time_secs: float, bucket_size: Optional[float] = None) -> None:
self._max_executions = max_executions
self._per_time_secs = per_time_secs

self._bucket_size = bucket_size if bucket_size else max_executions

self._loop = asyncio.get_running_loop()
self._token_per_secs = self._per_time_secs / self._max_executions

self._tokens = self._bucket_size
self._next_replenish_at = self._loop.time() + self._token_per_secs
self._token_bucket = TokenBucket(max_executions, per_time_secs, bucket_size)

@property
def tokens(self) -> float:
return self._tokens

@property
def empty(self) -> bool:
return self._tokens <= 0
def bucket(self) -> TokenBucket:
return self._token_bucket

async def acquire(self) -> None:
if not self.empty:
self._tokens -= 1
return

now = self._loop.time()

next_replenish = self._next_replenish_at
until_next_replenish = next_replenish - now

if until_next_replenish > 0:
raise RateLimitExceeded

tokens_to_add = min(self._bucket_size, 1 + abs(until_next_replenish / self._token_per_secs))

self._next_replenish_at = max(
next_replenish + tokens_to_add * self._token_per_secs,
now + self._token_per_secs,
)

# account for the current call
self._tokens = tokens_to_add - 1
return
try:
await self._token_bucket.take()
except EmptyBucket as e:
raise RateLimitExceeded from e


class LeakyTokenBucketLimiter(RateLimiter):
Expand Down
4 changes: 2 additions & 2 deletions hyx/retry/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from typing import Any, Callable, Optional, Sequence, cast

from hyx.events import EventDispatcher, EventManager, get_default_name
from hyx.ratelimit.managers import TokenBucketLimiter
from hyx.ratelimit.buckets import TokenBucket
from hyx.retry.events import _RETRY_LISTENERS, RetryListener
from hyx.retry.manager import RetryManager
from hyx.retry.typing import AttemptsT, BackoffsT, BucketRetryT
Expand Down Expand Up @@ -77,7 +77,7 @@ def bucket_retry(
"""

def _decorator(func: FuncT) -> FuncT:
limiter = TokenBucketLimiter(attempts, per_time_secs, bucket_size) if attempts and per_time_secs else None
limiter = TokenBucket(attempts, per_time_secs, bucket_size) if attempts and per_time_secs else None
event_dispatcher = EventDispatcher[RetryManager, RetryListener](
listeners,
_RETRY_LISTENERS,
Expand Down
6 changes: 3 additions & 3 deletions hyx/retry/manager.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import asyncio
from typing import Any, Optional

from hyx.ratelimit.managers import TokenBucketLimiter
from hyx.ratelimit.buckets import TokenBucket
from hyx.retry.backoffs import create_backoff
from hyx.retry.counters import create_counter
from hyx.retry.events import RetryListener
Expand All @@ -28,7 +28,7 @@ def __init__(
attempts: AttemptsT,
backoff: BackoffsT,
event_dispatcher: RetryListener,
limiter: Optional[TokenBucketLimiter] = None,
limiter: Optional[TokenBucket] = None,
) -> None:
self._name = name
self._exceptions = exceptions
Expand All @@ -49,7 +49,7 @@ async def __call__(self, func: FuncT) -> Any:
while bool(counter):
try:
if self._limiter is not None:
await self._limiter.acquire()
await self._limiter.take()

result = await func()

Expand Down
35 changes: 35 additions & 0 deletions tests/test_ratelimiter/test_buckets.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
import asyncio

import pytest

from hyx.ratelimit.buckets import TokenBucket
from hyx.ratelimit.exceptions import EmptyBucket


async def test__token_bucket_success() -> None:
bucket = TokenBucket(3, 1, 3)

for i in range(3):
assert bucket.tokens == (3 - i)
await bucket.take()
assert bucket.empty is True


async def test__token_bucket_limit_exceeded() -> None:
bucket = TokenBucket(3, 1, 3)

with pytest.raises(EmptyBucket):
for _ in range(4):
await bucket.take()


async def test__token_bucket__fully_replenish_after_time_period() -> None:
bucket = TokenBucket(3, 1, 3)

for _ in range(3):
await bucket.take()

await asyncio.sleep(3)

assert bucket.tokens == 3
assert bucket.empty is False

0 comments on commit c104a40

Please sign in to comment.