diff --git a/edward2/maps.py b/edward2/maps.py index 6dc09952..6511a1ee 100644 --- a/edward2/maps.py +++ b/edward2/maps.py @@ -38,7 +38,7 @@ def robust_map( max_workers: int | None = ..., raise_error: Literal[False] = ..., retry_exception_types: list[type[Exception]] | None = ..., -) -> Sequence[U | V]: +) -> list[U | V]: ... @@ -52,7 +52,7 @@ def robust_map( max_workers: int | None = ..., raise_error: Literal[True] = ..., retry_exception_types: list[type[Exception]] | None = ..., -) -> Sequence[U]: +) -> list[U]: ... @@ -66,7 +66,7 @@ def robust_map( max_workers: int | None = None, raise_error: bool = False, retry_exception_types: list[type[Exception]] | None = None, -) -> Sequence[U | V]: +) -> list[U | V]: """Maps a function to inputs using a threadpool. The map supports exception handling, retries with exponential backoff, and @@ -107,12 +107,18 @@ def robust_map( fn_with_backoff = tenacity.retry( retry=retry, wait=tenacity.wait_random_exponential(min=1, max=30), + before_sleep=tenacity.before_sleep_log( + logging.get_absl_logger(), logging.WARNING + ), )(fn) else: fn_with_backoff = tenacity.retry( retry=retry, wait=tenacity.wait_random_exponential(min=1, max=30), stop=tenacity.stop_after_attempt(max_retries + 1), + before_sleep=tenacity.before_sleep_log( + logging.get_absl_logger(), logging.WARNING + ), )(fn) if index_to_output is None: index_to_output = {}