Skip to content

Commit

Permalink
Replace manual progress printing with tqdm.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 630473171
  • Loading branch information
dustinvtran authored and edward-bot committed May 3, 2024
1 parent 83607d9 commit 1769a63
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 26 deletions.
30 changes: 5 additions & 25 deletions edward2/maps.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,12 @@
"""A better map."""

import concurrent.futures
import datetime
from typing import Callable, Literal, Sequence, TypeVar, overload
from typing import Callable, Literal, overload, Sequence, TypeVar

from absl import logging
import grpc
import tenacity
import tqdm

T = TypeVar('T')
U = TypeVar('U')
Expand All @@ -34,7 +34,6 @@ def robust_map(
inputs: Sequence[T],
error_output: V = ...,
index_to_output: dict[int, U | V] | None = ...,
log_percent: float = ...,
max_retries: int | None = ...,
max_workers: int | None = ...,
raise_error: Literal[False] = ...,
Expand All @@ -49,7 +48,6 @@ def robust_map(
inputs: Sequence[T],
error_output: V = ...,
index_to_output: dict[int, U | V] | None = ...,
log_percent: float = ...,
max_retries: int | None = ...,
max_workers: int | None = ...,
raise_error: Literal[True] = ...,
Expand All @@ -64,7 +62,6 @@ def robust_map(
inputs: Sequence[T],
error_output: V = None,
index_to_output: dict[int, U | V] | None = None,
log_percent: float = 5,
max_retries: int | None = None,
max_workers: int | None = None,
raise_error: bool = False,
Expand All @@ -82,7 +79,6 @@ def robust_map(
`max_retries`.
index_to_output: Optional dictionary to be used to store intermediate
results in-place.
log_percent: At every `log_percent` percent of items, log the progress.
max_retries: The maximum number of times to retry each input. If None, then
there is no limit. If limit, the output is set to `error_output` or an
error is raised if `raise_error` is set to True.
Expand Down Expand Up @@ -123,20 +119,20 @@ def robust_map(
num_existing = len(index_to_output)
num_inputs = len(inputs)
logging.info('Found %s/%s existing examples.', num_existing, num_inputs)
progress_bar = tqdm.tqdm(total=num_inputs - num_existing, desc='robust_map')
indices = [i for i in range(num_inputs) if i not in index_to_output.keys()]
log_steps = max(1, num_inputs * log_percent // 100)
with concurrent.futures.ThreadPoolExecutor(
max_workers=max_workers
) as executor:
future_to_index = {
executor.submit(fn_with_backoff, inputs[i]): i for i in indices
}
start = datetime.datetime.now()
for future in concurrent.futures.as_completed(future_to_index):
index = future_to_index[future]
try:
output = future.result()
index_to_output[index] = output
progress_bar.update(1)
except tenacity.RetryError as e:
if raise_error:
logging.exception('Item %s exceeded max retries.', index)
Expand All @@ -150,22 +146,6 @@ def robust_map(
e,
)
index_to_output[index] = error_output
num_so_far = len(index_to_output)
if num_so_far % log_steps == 0 or num_so_far == num_inputs:
end = datetime.datetime.now()
elapsed = end - start
num_completed = num_so_far - num_existing
avg_per_example = elapsed / num_completed
num_remaining = num_inputs - num_so_far
eta = avg_per_example * num_remaining
logging.info(
'Completed %d/%d inputs. Elapsed time (started with %d inputs): %s.'
' ETA: %s.',
num_so_far,
num_inputs,
num_existing,
elapsed,
eta,
)
progress_bar.update(1)
outputs = [index_to_output[i] for i in range(num_inputs)]
return outputs
4 changes: 3 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,9 @@
url='http://github.com/google/edward2',
license='Apache 2.0',
packages=find_packages(),
install_requires=[],
install_requires=[
'tqdm',
],
extras_require={
'jax': ['jax>=0.2.13',
'flax>=0.3.4'],
Expand Down

0 comments on commit 1769a63

Please sign in to comment.