Skip to content

Commit

Permalink
Replace manual progress printing with tqdm.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 630473171
  • Loading branch information
dustinvtran authored and edward-bot committed May 3, 2024
1 parent 83607d9 commit fed8b71
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 22 deletions.
26 changes: 5 additions & 21 deletions edward2/maps.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,12 @@

import concurrent.futures
import datetime
from typing import Callable, Literal, Sequence, TypeVar, overload
from typing import Callable, Literal, overload, Sequence, TypeVar

from absl import logging
import grpc
import tenacity
import tqdm

T = TypeVar('T')
U = TypeVar('U')
Expand Down Expand Up @@ -64,7 +65,6 @@ def robust_map(
inputs: Sequence[T],
error_output: V = None,
index_to_output: dict[int, U | V] | None = None,
log_percent: float = 5,
max_retries: int | None = None,
max_workers: int | None = None,
raise_error: bool = False,
Expand All @@ -82,7 +82,6 @@ def robust_map(
`max_retries`.
index_to_output: Optional dictionary to be used to store intermediate
results in-place.
log_percent: At every `log_percent` percent of items, log the progress.
max_retries: The maximum number of times to retry each input. If None, then
there is no limit. If the limit is reached, the output is set to
`error_output`, or an error is raised if `raise_error` is set to True.
Expand All @@ -98,6 +97,7 @@ def robust_map(
A list of items each of type U. They are the outputs of `fn` applied to
the elements of `inputs`.
"""
progress_bar = tqdm.tqdm(total=len(inputs), desc='robust_map')
if retry_exception_types is None:
retry_exception_types = []
retry_exception_types = retry_exception_types + [
Expand All @@ -124,7 +124,6 @@ def robust_map(
num_inputs = len(inputs)
logging.info('Found %s/%s existing examples.', num_existing, num_inputs)
indices = [i for i in range(num_inputs) if i not in index_to_output.keys()]
log_steps = max(1, num_inputs * log_percent // 100)
with concurrent.futures.ThreadPoolExecutor(
max_workers=max_workers
) as executor:
Expand All @@ -137,6 +136,7 @@ def robust_map(
try:
output = future.result()
index_to_output[index] = output
progress_bar.update(1)
except tenacity.RetryError as e:
if raise_error:
logging.exception('Item %s exceeded max retries.', index)
Expand All @@ -150,22 +150,6 @@ def robust_map(
e,
)
index_to_output[index] = error_output
num_so_far = len(index_to_output)
if num_so_far % log_steps == 0 or num_so_far == num_inputs:
end = datetime.datetime.now()
elapsed = end - start
num_completed = num_so_far - num_existing
avg_per_example = elapsed / num_completed
num_remaining = num_inputs - num_so_far
eta = avg_per_example * num_remaining
logging.info(
'Completed %d/%d inputs. Elapsed time (started with %d inputs): %s.'
' ETA: %s.',
num_so_far,
num_inputs,
num_existing,
elapsed,
eta,
)
progress_bar.update(1)
outputs = [index_to_output[i] for i in range(num_inputs)]
return outputs
4 changes: 3 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,9 @@
url='http://github.com/google/edward2',
license='Apache 2.0',
packages=find_packages(),
install_requires=[],
install_requires=[
'tqdm',
],
extras_require={
'jax': ['jax>=0.2.13',
'flax>=0.3.4'],
Expand Down

0 comments on commit fed8b71

Please sign in to comment.