Skip to content

Commit

Permalink
Consistent errors for unused streams in batching methods (#826)
Browse the repository at this point in the history
  • Loading branch information
snarayan21 authored Nov 11, 2024
1 parent 3f22e07 commit 193b6fc
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 3 deletions.
5 changes: 5 additions & 0 deletions streaming/base/batching/device_per_stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,11 @@ def generate_work_device_per_stream_batching(dataset: StreamingDataset, world: W
raise TypeError(f'Dataset `shuffle_block_size` must be an integer. ' +
f'Got {type(dataset.shuffle_block_size)} instead.')
shuffle_block_portion = int(dataset.shuffle_block_size * stream.proportion)
if shuffle_block_portion == 0:
raise ValueError(f'Samples from stream {stream_id} are not being used. Please ' +
f'either increase the `shuffle_block_size` from ' +
f'{dataset.shuffle_block_size}, or increase the stream ' +
f'proportion from {stream.proportion}.')
stream_shuffle = get_shuffle(dataset.shuffle_algo, shuffle_units,
dataset.num_canonical_nodes, dataset.shuffle_seed, epoch,
shuffle_block_portion)
Expand Down
6 changes: 3 additions & 3 deletions streaming/base/batching/per_stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,9 +98,9 @@ def generate_work_per_stream_batching(dataset: StreamingDataset, world: World, e
if num_full_batches > 0:
batches_from_partitions.append(global_batches_inorder[:num_full_batches])
else:
logger.warning(f'Stream with index {stream_idx} does not have an adequate number of ' +
f'samples to construct a complete global batch. Training will occur ' +
f'without any samples from this stream!')
raise ValueError(f'Stream with index {stream_idx} does not have an adequate number ' +
f'of samples to construct a complete global batch. Training will ' +
f'occur without any samples from this stream.')

# Combine all global batches from all streams into one array.
all_partition_batches = np.concatenate(batches_from_partitions)
Expand Down

0 comments on commit 193b6fc

Please sign in to comment.