From fed8b71179f05471aa6a6100b5f605c23ee726da Mon Sep 17 00:00:00 2001 From: Dustin Tran Date: Fri, 3 May 2024 12:58:17 -0700 Subject: [PATCH] Replace manual progress printing with tqdm. PiperOrigin-RevId: 630473171 --- edward2/maps.py | 26 +++++--------------------- setup.py | 4 +++- 2 files changed, 8 insertions(+), 22 deletions(-) diff --git a/edward2/maps.py b/edward2/maps.py index d82444a3..0d1185e4 100644 --- a/edward2/maps.py +++ b/edward2/maps.py @@ -17,11 +17,12 @@ import concurrent.futures import datetime -from typing import Callable, Literal, Sequence, TypeVar, overload +from typing import Callable, Literal, overload, Sequence, TypeVar from absl import logging import grpc import tenacity +import tqdm T = TypeVar('T') U = TypeVar('U') @@ -64,7 +65,6 @@ def robust_map( inputs: Sequence[T], error_output: V = None, index_to_output: dict[int, U | V] | None = None, - log_percent: float = 5, max_retries: int | None = None, max_workers: int | None = None, raise_error: bool = False, @@ -82,7 +82,6 @@ def robust_map( `max_retries`. index_to_output: Optional dictionary to be used to store intermediate results in-place. - log_percent: At every `log_percent` percent of items, log the progress. max_retries: The maximum number of times to retry each input. If None, then there is no limit. If limit, the output is set to `error_output` or an error is raised if `raise_error` is set to True. @@ -98,6 +97,7 @@ def robust_map( A list of items each of type U. They are the outputs of `fn` applied to the elements of `inputs`. """ + progress_bar = tqdm.tqdm(total=len(inputs), desc='robust_map') if retry_exception_types is None: retry_exception_types = [] retry_exception_types = retry_exception_types + [ @@ -124,7 +124,6 @@ def robust_map( num_inputs = len(inputs) logging.info('Found %s/%s existing examples.', num_existing, num_inputs) indices = [i for i in range(num_inputs) if i not in index_to_output.keys()] - log_steps = max(1, num_inputs * log_percent // 100) with concurrent.futures.ThreadPoolExecutor( max_workers=max_workers ) as executor: @@ -137,6 +136,7 @@ def robust_map( try: output = future.result() index_to_output[index] = output + progress_bar.update(1) except tenacity.RetryError as e: if raise_error: logging.exception('Item %s exceeded max retries.', index) @@ -150,22 +150,6 @@ def robust_map( e, ) index_to_output[index] = error_output - num_so_far = len(index_to_output) - if num_so_far % log_steps == 0 or num_so_far == num_inputs: - end = datetime.datetime.now() - elapsed = end - start - num_completed = num_so_far - num_existing - avg_per_example = elapsed / num_completed - num_remaining = num_inputs - num_so_far - eta = avg_per_example * num_remaining - logging.info( - 'Completed %d/%d inputs. Elapsed time (started with %d inputs): %s.' - ' ETA: %s.', - num_so_far, - num_inputs, - num_existing, - elapsed, - eta, - ) + progress_bar.update(1) outputs = [index_to_output[i] for i in range(num_inputs)] return outputs diff --git a/setup.py b/setup.py index f490c5dc..2ba7f9e9 100644 --- a/setup.py +++ b/setup.py @@ -29,7 +29,9 @@ url='http://github.com/google/edward2', license='Apache 2.0', packages=find_packages(), - install_requires=[], + install_requires=[ + 'tqdm', + ], extras_require={ 'jax': ['jax>=0.2.13', 'flax>=0.3.4'],