From 5238ca31eed61d404a3ecdf134432086b91071c2 Mon Sep 17 00:00:00 2001 From: Mircea Mironenco Date: Tue, 28 Jan 2025 21:39:30 +0200 Subject: [PATCH] Fix max_concurrent instance check and error message (#1420) Co-authored-by: Andrew Ho --- torchdata/nodes/map.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/torchdata/nodes/map.py b/torchdata/nodes/map.py index 04430d894..fab6a0f6c 100644 --- a/torchdata/nodes/map.py +++ b/torchdata/nodes/map.py @@ -318,8 +318,8 @@ def __init__( self._mp_context = mp.get_context(self.multiprocessing_context) if max_concurrent is not None and num_workers > 0: - if not isinstance(max_concurrent, int) and max_concurrent > num_workers: - raise ValueError(f"{max_concurrent=} should be >= {num_workers=}!") + if isinstance(max_concurrent, int) and max_concurrent > num_workers: + raise ValueError(f"{max_concurrent=} should be <= {num_workers=}!") self.max_concurrent = max_concurrent self.snapshot_frequency = snapshot_frequency self._it: Optional[Union[_InlineMapperIter[T], _ParallelMapperIter[T]]] = None @@ -404,8 +404,8 @@ def __init__( self.method = method self.multiprocessing_context = multiprocessing_context if max_concurrent is not None and num_workers > 0: - if not isinstance(max_concurrent, int) and max_concurrent > num_workers: - raise ValueError(f"{max_concurrent=} should be >= {num_workers=}!") + if isinstance(max_concurrent, int) and max_concurrent > num_workers: + raise ValueError(f"{max_concurrent=} should be <= {num_workers=}!") self.max_concurrent = max_concurrent self.snapshot_frequency = snapshot_frequency self.prebatch = prebatch