From 1769a639c89a8212ed81560b7c8682a8b6215b64 Mon Sep 17 00:00:00 2001 From: Dustin Tran Date: Fri, 3 May 2024 12:58:17 -0700 Subject: [PATCH] Replace manual progress printing with tqdm. PiperOrigin-RevId: 630473171 --- edward2/maps.py | 30 +++++------------------------- setup.py | 4 +++- 2 files changed, 8 insertions(+), 26 deletions(-) diff --git a/edward2/maps.py b/edward2/maps.py index d82444a3..6dc09952 100644 --- a/edward2/maps.py +++ b/edward2/maps.py @@ -16,12 +16,12 @@ """A better map.""" import concurrent.futures -import datetime -from typing import Callable, Literal, Sequence, TypeVar, overload +from typing import Callable, Literal, overload, Sequence, TypeVar from absl import logging import grpc import tenacity +import tqdm T = TypeVar('T') U = TypeVar('U') @@ -34,7 +34,6 @@ def robust_map( inputs: Sequence[T], error_output: V = ..., index_to_output: dict[int, U | V] | None = ..., - log_percent: float = ..., max_retries: int | None = ..., max_workers: int | None = ..., raise_error: Literal[False] = ..., @@ -49,7 +48,6 @@ def robust_map( inputs: Sequence[T], error_output: V = ..., index_to_output: dict[int, U | V] | None = ..., - log_percent: float = ..., max_retries: int | None = ..., max_workers: int | None = ..., raise_error: Literal[True] = ..., @@ -64,7 +62,6 @@ def robust_map( inputs: Sequence[T], error_output: V = None, index_to_output: dict[int, U | V] | None = None, - log_percent: float = 5, max_retries: int | None = None, max_workers: int | None = None, raise_error: bool = False, @@ -82,7 +79,6 @@ def robust_map( `max_retries`. index_to_output: Optional dictionary to be used to store intermediate results in-place. - log_percent: At every `log_percent` percent of items, log the progress. max_retries: The maximum number of times to retry each input. If None, then there is no limit. If limit, the output is set to `error_output` or an error is raised if `raise_error` is set to True. @@ -123,20 +119,20 @@ def robust_map( num_existing = len(index_to_output) num_inputs = len(inputs) logging.info('Found %s/%s existing examples.', num_existing, num_inputs) + progress_bar = tqdm.tqdm(total=num_inputs - num_existing, desc='robust_map') indices = [i for i in range(num_inputs) if i not in index_to_output.keys()] - log_steps = max(1, num_inputs * log_percent // 100) with concurrent.futures.ThreadPoolExecutor( max_workers=max_workers ) as executor: future_to_index = { executor.submit(fn_with_backoff, inputs[i]): i for i in indices } - start = datetime.datetime.now() for future in concurrent.futures.as_completed(future_to_index): index = future_to_index[future] try: output = future.result() index_to_output[index] = output + progress_bar.update(1) except tenacity.RetryError as e: if raise_error: logging.exception('Item %s exceeded max retries.', index) @@ -150,22 +146,6 @@ def robust_map( e, ) index_to_output[index] = error_output - num_so_far = len(index_to_output) - if num_so_far % log_steps == 0 or num_so_far == num_inputs: - end = datetime.datetime.now() - elapsed = end - start - num_completed = num_so_far - num_existing - avg_per_example = elapsed / num_completed - num_remaining = num_inputs - num_so_far - eta = avg_per_example * num_remaining - logging.info( - 'Completed %d/%d inputs. Elapsed time (started with %d inputs): %s.' - ' ETA: %s.', - num_so_far, - num_inputs, - num_existing, - elapsed, - eta, - ) + progress_bar.update(1) outputs = [index_to_output[i] for i in range(num_inputs)] return outputs diff --git a/setup.py b/setup.py index f490c5dc..2ba7f9e9 100644 --- a/setup.py +++ b/setup.py @@ -29,7 +29,9 @@ url='http://github.com/google/edward2', license='Apache 2.0', packages=find_packages(), - install_requires=[], + install_requires=[ + 'tqdm', + ], extras_require={ 'jax': ['jax>=0.2.13', 'flax>=0.3.4'],