Skip to content

Commit

Permalink
Compute backoff ourselves and make Attempt.next_wait meaningful (#74)
Browse files Browse the repository at this point in the history
* Remove backoff from RetryingCaller tests

* Compute backoff ourselves and make Attempt.next_wait meaningful

* Ensure both backoff methods clamp correctly
  • Loading branch information
hynek authored Aug 18, 2024
1 parent 87e92f1 commit 8d4c5f7
Show file tree
Hide file tree
Showing 3 changed files with 107 additions and 40 deletions.
109 changes: 83 additions & 26 deletions src/stamina/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from __future__ import annotations

import datetime as dt
import random
import sys

from dataclasses import dataclass, replace
Expand All @@ -15,6 +16,7 @@
AsyncIterator,
Awaitable,
Callable,
ClassVar,
Iterator,
Tuple,
Type,
Expand Down Expand Up @@ -108,12 +110,17 @@ class Attempt:
.. versionadded:: 23.2.0
"""

__slots__ = ("_t_attempt",)
__slots__ = ("_t_attempt", "_next_wait_fn")

_t_attempt: _t.AttemptManager

def __init__(self, attempt: _t.AttemptManager):
def __init__(
self,
attempt: _t.AttemptManager,
next_wait_fn: Callable[[int], float] | None,
):
self._t_attempt = attempt
self._next_wait_fn = next_wait_fn

def __repr__(self) -> str:
return f"<Attempt num={self.num}, next_wait={float(self.next_wait)}>"
Expand All @@ -131,9 +138,18 @@ def next_wait(self) -> float:
The number of seconds of backoff before the *next* attempt if *this*
attempt fails.
.. warning::
This value does **not** include a possible random jitter and is
therefore just a *lower bound* of the actual value.
.. versionadded:: 24.3.0
"""
return self._t_attempt.retry_state.upcoming_sleep # type: ignore[no-any-return]
return (
self._next_wait_fn(self._t_attempt.retry_state.attempt_number + 1)
if self._next_wait_fn
else 0.0
)

def __enter__(self) -> None:
return self._t_attempt.__enter__() # type: ignore[no-any-return]
Expand Down Expand Up @@ -382,13 +398,30 @@ def __aiter__(self) -> _t.AsyncRetrying:

@dataclass
class _RetryContextIterator:
__slots__ = ("_t_kw", "_t_a_retrying", "_name", "_args", "_kw")
__slots__ = (
"_t_kw",
"_t_a_retrying",
"_name",
"_args",
"_kw",
"_wait_jitter",
"_wait_initial",
"_wait_max",
"_wait_exp_base",
)
_t_kw: dict[str, object]
_t_a_retrying: _t.AsyncRetrying
_name: str
_args: tuple[object, ...]
_kw: dict[str, object]

_wait_jitter: float
_wait_initial: float
_wait_max: float
_wait_exp_base: float

_random: ClassVar[random.Random] = random.Random() # noqa: S311

@classmethod
def from_params(
cls,
Expand All @@ -411,30 +444,26 @@ def from_params(
_retry = _t.retry_if_exception_type(on)
else:
_retry = _t.retry_if_exception(on)
return cls(

if isinstance(wait_initial, dt.timedelta):
wait_initial = wait_initial.total_seconds()

if isinstance(wait_max, dt.timedelta):
wait_max = wait_max.total_seconds()

if isinstance(wait_jitter, dt.timedelta):
wait_jitter = wait_jitter.total_seconds()

inst = cls(
_name=name,
_args=args,
_kw=kw,
_wait_jitter=wait_jitter,
_wait_initial=wait_initial,
_wait_max=wait_max,
_wait_exp_base=wait_exp_base,
_t_kw={
"retry": _retry,
"wait": _t.wait_exponential_jitter(
initial=(
wait_initial.total_seconds()
if isinstance(wait_initial, dt.timedelta)
else wait_initial
),
max=(
wait_max.total_seconds()
if isinstance(wait_max, dt.timedelta)
else wait_max
),
exp_base=wait_exp_base,
jitter=(
wait_jitter.total_seconds()
if isinstance(wait_jitter, dt.timedelta)
else wait_jitter
),
),
"stop": _make_stop(
attempts=attempts,
timeout=(
Expand All @@ -448,6 +477,10 @@ def from_params(
_t_a_retrying=_LAZY_NO_ASYNC_RETRY,
)

inst._t_kw["wait"] = inst._jittered_backoff_for_rcs

return inst

def with_name(
self, name: str, args: tuple[object, ...], kw: dict[str, object]
) -> _RetryContextIterator:
Expand All @@ -459,7 +492,7 @@ def with_name(
def __iter__(self) -> Iterator[Attempt]:
if not CONFIG.is_active:
for r in _t.Retrying(reraise=True, stop=_STOP_NO_RETRY):
yield Attempt(r)
yield Attempt(r, None)

return

Expand All @@ -469,7 +502,7 @@ def __iter__(self) -> Iterator[Attempt]:
),
**self._t_kw,
):
yield Attempt(r)
yield Attempt(r, self._backoff_for_attempt_number)

def __aiter__(self) -> AsyncIterator[Attempt]:
if CONFIG.is_active:
Expand All @@ -486,7 +519,31 @@ def __aiter__(self) -> AsyncIterator[Attempt]:
return self

async def __anext__(self) -> Attempt:
return Attempt(await self._t_a_retrying.__anext__())
return Attempt(
await self._t_a_retrying.__anext__(),
self._backoff_for_attempt_number,
)

def _backoff_for_attempt_number(self, num: int) -> float:
"""
Compute a jitter-less lower bound for backoff number *num*.
*num* is 1-based.
"""
return min(
self._wait_max,
self._wait_initial * (self._wait_exp_base ** (num - 1)),
)

def _jittered_backoff_for_rcs(self, rcs: _t.RetryCallState) -> float:
"""
Compute the backoff for *rcs*.
"""
return min(
self._wait_max,
self._backoff_for_attempt_number(rcs.attempt_number)
+ self._random.uniform(0, self._wait_jitter),
)


def _make_before_sleep(
Expand Down
11 changes: 3 additions & 8 deletions tests/test_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,18 +199,13 @@ async def test_next_wait():
"""
The next_wait property is updated.
"""
i = 0

async for attempt in stamina.retry_context(on=ValueError, wait_max=0.001):
async for attempt in stamina.retry_context(on=ValueError, wait_max=0.0001):
with attempt:
if i == 0:
assert 0.0 == attempt.next_wait
assert pytest.approx(0.0001) == attempt.next_wait

i += 1
if attempt.num == 1:
raise ValueError

assert pytest.approx(0.001) == attempt.next_wait


async def test_retry_blocks_can_be_disabled():
"""
Expand Down
27 changes: 21 additions & 6 deletions tests/test_sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@

import datetime as dt

from types import SimpleNamespace

import pytest
import tenacity

Expand Down Expand Up @@ -171,17 +173,30 @@ def test_next_wait():
"""
The next_wait property is updated.
"""
i = 0

for attempt in stamina.retry_context(on=ValueError, wait_max=0.001):
for attempt in stamina.retry_context(on=ValueError, wait_max=0.0001):
with attempt:
if i == 0:
assert 0.0 == attempt.next_wait
assert pytest.approx(0.0001) == attempt.next_wait

i += 1
if attempt.num == 1:
raise ValueError

assert pytest.approx(0.001) == attempt.next_wait

def test_backoff_computation_clamps():
"""
The backoff returned by _RetryContextIterator._backoff_for_attempt_number
and _RetryContextIterator._jittered_backoff_for_rcs never exceeds wait_max.
"""
rci = stamina.retry_context(on=ValueError, wait_max=0.42)

for i in range(1, 10):
backoff = rci._backoff_for_attempt_number(i)
assert backoff <= 0.42

jittered = rci._jittered_backoff_for_rcs(
SimpleNamespace(attempt_number=i)
)
assert jittered <= 0.42


class TestMakeStop:
Expand Down

0 comments on commit 8d4c5f7

Please sign in to comment.