From a14bd438628dd063bfaec7b23b6dd47c49bcf572 Mon Sep 17 00:00:00 2001
From: hanhui <193691140@qq.com>
Date: Sun, 28 Jul 2024 02:02:47 +0800
Subject: [PATCH] Fix edge-case bugs in batching, worker association, and metadata loading

---
 benchmarks/benchmark_asyncload.py  |  2 +-
 benchmarks/benckmark_streamload.py |  2 +-
 parquet_loader/dataloader.py       | 11 ++--
 parquet_loader/dataset.py          | 95 +++++++++++++++++++-----------
 parquet_loader/reader.py           | 17 ++++--
 parquet_loader/shuffle.py          |  8 +--
 parquet_loader/utils.py            | 71 ++++++++++++++--------
 tests/test_associate.py            |  5 +-
 tests/test_parallel.py             |  5 +-
 tests/test_single.py               |  3 +-
 10 files changed, 138 insertions(+), 81 deletions(-)

diff --git a/benchmarks/benchmark_asyncload.py b/benchmarks/benchmark_asyncload.py
index 756af0a..34cda2a 100644
--- a/benchmarks/benchmark_asyncload.py
+++ b/benchmarks/benchmark_asyncload.py
@@ -8,7 +8,7 @@
 from parquet_loader import ParquetDataset, ParquetDataLoader
 from parquet_loader.reader import AsyncParquetReader, SyncParquetReader
 
-path = 'synthetic_data'
+path = '../synthetic_data'
 delay_in_seconds = 0.01
 
 
diff --git a/benchmarks/benckmark_streamload.py b/benchmarks/benckmark_streamload.py
index 70bce34..aa93340 100644
--- a/benchmarks/benckmark_streamload.py
+++ b/benchmarks/benckmark_streamload.py
@@ -10,7 +10,7 @@
 from memory_profiler import profile as mem_profile
 
 ## config
-path = 'synthetic_data'
+path = '../synthetic_data'
 num_workers = 4
 batch_size = 66
 
diff --git a/parquet_loader/dataloader.py b/parquet_loader/dataloader.py
index 03df84d..86bfbb1 100644
--- a/parquet_loader/dataloader.py
+++ b/parquet_loader/dataloader.py
@@ -25,18 +25,19 @@ def squeeze_first_dim(data):
     else:
         return data
 
+
 class _SingleProcessDataLoaderIter(TorchSingleProcessDataLoaderIter):
     def _next_data(self):
         data = super()._next_data()
         return squeeze_first_dim(data)
-    
+
+
 class _MultiProcessingDataLoaderIter(TorchMultiProcessingDataLoaderIter):
     def _next_data(self):
         data = super()._next_data()
         return squeeze_first_dim(data)
 
 
-
 class ParquetDataLoader(DataLoader):
     def __init__(
         self,
@@ -60,7 +61,7 @@ def __init__(
     ):
         dataset.set_shuffle(shuffle)
         dataset.set_drop_last(drop_last)
-        self._num_batches = dataset.get_num_batches(batch_size)
+        dataset.set_batch_size(batch_size)
 
         # reset arguments
         shuffle = False
@@ -91,7 +92,9 @@ def __init__(
 
 
     def __len__(self) -> int:
-        return self._num_batches
+        if not hasattr(self, 'num_batches'):
+            self.num_batches = self.dataset.get_num_batches()
+        return self.num_batches
 
     def _get_iterator(self) -> '_BaseDataLoaderIter':
         if self.num_workers == 0:
diff --git a/parquet_loader/dataset.py b/parquet_loader/dataset.py
index 906eec4..8bc61bd 100644
--- a/parquet_loader/dataset.py
+++ b/parquet_loader/dataset.py
@@ -1,3 +1,4 @@
+import logging
 from typing import List, Optional
 
 import pandas as pd
@@ -15,6 +16,9 @@
 from .reader import SyncParquetReader, AsyncParquetReader
 
 
+logger = logging.getLogger(__name__)
+
+
 class ParquetDataset(IterableDataset):
     def __init__(
         self,
@@ -22,31 +26,20 @@ def __init__(
         self,
         path: str,
         column_names: Optional[List[str]] = None,
         async_read: bool = False,
         max_preload: int = 1
-    ):
-        _ds = ds.dataset(path)
-        self.metas = [
-            ParquetMetadata(
-                file_path=f.path,
-                num_rows=f.metadata.num_rows,
-                num_row_groups=f.metadata.num_row_groups,
-                num_rows_per_row_group=[
-                    f.metadata.row_group(i).num_rows
-                    for i in range(f.metadata.num_row_groups)
-                ]
-            )
-            for f in _ds.get_fragments()
-        ]
-        self.column_names = column_names or _ds.schema.names
-        del _ds
-        self.world_size, self.global_rank, self.num_nodes = detect_distributed_env()
+    ):
+        self.path = path
+        self.column_names = column_names
         self.async_read = async_read
         self.max_preload = max_preload
-        Reader = AsyncParquetReader if async_read and max_preload > 0 else SyncParquetReader
-        self.reader = Reader(self.max_preload)
+        self.reader = AsyncParquetReader(self.max_preload) \
+            if async_read and max_preload > 0 else \
+            SyncParquetReader(self.max_preload)
+        self.world_size, self.global_rank, self.num_nodes = detect_distributed_env()
+        self.batch_size = 1
 
     def __len__(self):
-        if not getattr(self, 'num_rows', None):
+        if not hasattr(self, 'num_rows'):
             self._associate_to_workers()
         return self.num_rows
@@ -56,18 +49,42 @@ def set_shuffle(self, shuffle: bool) -> None:
 
 
     def set_drop_last(self, drop_last: bool) -> None:
         self.drop_last = drop_last
-    
-    def get_num_batches(self, batch_size=1):
-        assert hasattr(self, 'drop_last'), 'call `set_drop_last` before call `get_num_batches`'
+
+    def set_batch_size(self, batch_size: int) -> None:
         self.batch_size = batch_size
+
+    def get_num_batches(self, batch_size=None):
+        assert hasattr(self, 'drop_last'), 'call `set_drop_last` before call `get_num_batches`'
+        batch_size = batch_size or self.batch_size
         if self.drop_last:
             return len(self) // batch_size
         else:
             return (len(self) + batch_size - 1) // batch_size
+
+    def _try_fetch_metadata(self):
+        if hasattr(self, 'metas'):
+            return
+
+        _ds = ds.dataset(self.path)
+        self.metas = [
+            ParquetMetadata(
+                file_path=f.path,
+                num_rows=f.metadata.num_rows,
+                num_row_groups=f.metadata.num_row_groups,
+                num_rows_per_row_group=[
+                    f.metadata.row_group(i).num_rows
+                    for i in range(f.metadata.num_row_groups)
+                ]
+            )
+            for f in _ds.get_fragments()
+        ]
+        self.column_names = self.column_names or _ds.schema.names
+        del _ds
 
 
     def _associate_to_workers(self):
+        self._try_fetch_metadata()
         self.num_workers, self.worker_rank = detect_worker_env()
         self.intervals, self.num_rows = self.shuffler.associate_to_workers(
             metas=self.metas,
@@ -88,16 +105,19 @@ def __iter__(self):
 
 
     def __getitem__(self, index: int) -> pd.DataFrame:
+        logger.warning("calling `__getitem__` is inefficient; intended for testing only.")
+        self._try_fetch_metadata()
         global_index = 0
-        for itv in self.intervals:
-            offset = itv.local_row_end - itv.local_row_start
-            global_index += offset
-            if global_index > index:
-                f = pq.ParquetFile(self.metas[itv.file_index])
-                table = f.read_row_group(itv.row_group_index).select(self.column_names)
-                f.close()
-                return table.slice(index - (global_index - offset), 1)\
-                    .to_pandas(split_blocks=True, self_destruct=True).to_numpy()
+        for itvs in self.intervals:
+            for itv in itvs:
+                offset = itv.local_row_end - itv.local_row_start
+                global_index += offset
+                if global_index > index:
+                    f = pq.ParquetFile(self.metas[itv.file_index])
+                    table = f.read_row_group(itv.row_group_index).select(self.column_names)
+                    f.close()
+                    return table.slice(index - (global_index - offset), 1)\
+                        .to_pandas(split_blocks=True, self_destruct=True).to_numpy()
 
 
     def iter_batch(self):
@@ -123,8 +143,11 @@ def iter_batch(self):
                 tables = [table_left]  # reset tables
 
         # last batch, it may not be full batch
-        table_left = pa.concat_tables(tables)
-        batch_data = table_left.slice(0, self.batch_size)\
-            .to_pandas(split_blocks=True, self_destruct=True).to_numpy()
-        yield batch_data
+        if len(tables) > 0:
+            table_left = pa.concat_tables(tables)
+            if table_left.shape[0] == 0:
+                return
+            batch_data = table_left.slice(0, self.batch_size)\
+                .to_pandas(split_blocks=True, self_destruct=True).to_numpy()
+            yield batch_data
 
diff --git a/parquet_loader/reader.py b/parquet_loader/reader.py
index 8ed7e9d..32a8299 100644
--- a/parquet_loader/reader.py
+++ b/parquet_loader/reader.py
@@ -16,7 +16,7 @@ class Reader:
     def __init__(self, max_preload: int = 1):
         self.max_preload = max_preload
 
-    def setup(self, metas: List[ParquetMetadata], intervals: Dict[int, List[RowGroupInterval]]):
+    def setup(self, metas: List[ParquetMetadata], intervals: List[List[RowGroupInterval]]):
         self.metas = metas
         self.intervals = intervals
 
@@ -39,8 +39,8 @@ class SyncParquetReader(Reader):
 
     @property
     def table_iterator(self):
-        for fi, itvs in self.intervals.items():
-            with self._open_parquet_file(self.metas[fi].file_path) as pf:
+        for itvs in self.intervals:
+            with self._open_parquet_file(self.metas[itvs[0].file_index].file_path) as pf:
                 for itv in itvs:
                     offset = itv.local_row_end - itv.local_row_start
                     yield pf.read_row_group(itv.row_group_index).slice(itv.local_row_start, offset)
@@ -52,10 +52,15 @@ class AsyncParquetReader(Reader):
     _END_TOKEN = "_END"
     _DEFAULT_TIMEOUT = 1
 
-    def _preload(self, metas, intervals, queue):
+    def _preload(
+        self,
+        metas: List[ParquetMetadata],
+        intervals: List[List[RowGroupInterval]],
+        queue: Queue
+    ):
         try:
-            for fi, itvs in intervals.items():
-                with self._open_parquet_file(metas[fi].file_path) as pf:
+            for itvs in intervals:
+                with self._open_parquet_file(metas[itvs[0].file_index].file_path) as pf:
                     for itv in itvs:
                         offset = itv.local_row_end - itv.local_row_start
                         table = pf.read_row_group(itv.row_group_index).slice(itv.local_row_start, offset)
diff --git a/parquet_loader/shuffle.py b/parquet_loader/shuffle.py
index eca9935..54f2f83 100644
--- a/parquet_loader/shuffle.py
+++ b/parquet_loader/shuffle.py
@@ -22,7 +22,7 @@ def associate_to_workers(
         current_worker_rank: int = 0,
         drop_last: bool = False,
         batch_size: int = 1,
-    ) -> Tuple[Dict[int, List[RowGroupInterval]], int]:
+    ) -> Tuple[List[List[RowGroupInterval]], int]:
         raise NotImplementedError
 
     def shuffle(self, data: np.ndarray) -> np.ndarray:
@@ -41,7 +41,7 @@ def associate_to_workers(
         current_worker_rank: int = 0,
         drop_last: bool = False,
         batch_size: int = 1,
-    ) -> Tuple[Dict[int, List[RowGroupInterval]], int]:
+    ) -> Tuple[List[List[RowGroupInterval]], int]:
 
         return associate_to_workers(
             metas=metas,
@@ -68,7 +68,7 @@ def associate_to_workers(
         current_worker_rank: int = 0,
         drop_last: bool = False,
         batch_size: int = 1,
-    ) -> Tuple[Dict[int, List[RowGroupInterval]], int]:
+    ) -> Tuple[List[List[RowGroupInterval]], int]:
 
         # shuffle files
         metas = copy.deepcopy(metas)
@@ -85,7 +85,7 @@ def associate_to_workers(
         )
 
         # shuffle row groups
-        for fi, itvs in intervals.items():
+        for itvs in intervals:
             self.rng.shuffle(itvs)
 
         return intervals, num_rows
diff --git a/parquet_loader/utils.py b/parquet_loader/utils.py
index d145f1c..6d4d627 100644
--- a/parquet_loader/utils.py
+++ b/parquet_loader/utils.py
@@ -1,4 +1,5 @@
 import os
+import copy
 from typing import List, Tuple, Dict
 
 import torch
@@ -34,27 +35,32 @@ def associate_to_workers(
     current_worker_rank: int = 0,
     drop_last: bool = False,
     batch_size: int = 1,
-) -> Tuple[Dict[int, List[RowGroupInterval]], int]:
+) -> Tuple[List[List[RowGroupInterval]], int]:
 
     total_num_rows = sum([meta.num_rows for meta in metas])
+    rank0_extra_rows = total_num_rows % world_size  # rank 0 may get extra rows
     num_rows_per_ranks = [
-        total_num_rows // world_size + total_num_rows % world_size  # TODO: check its correctness
-        if rank == world_size - 1 and not drop_last
+        total_num_rows // world_size + rank0_extra_rows
+        if rank == 0 and not drop_last
         else total_num_rows // world_size
         for rank in range(world_size)
     ]
+    ratio = num_workers * batch_size
     if drop_last:
-        ratio = num_workers * batch_size
-        num_rows_per_ranks = [ratio * int(item // ratio) for item in num_rows_per_ranks]
-    num_rows_per_workers = np.array([
-        [
-            num_rows_per_ranks[rank] // num_workers + num_rows_per_ranks[rank] % num_workers
-            if worker_rank == 0 and not drop_last  # worker 0 gets the remainder
-            else num_rows_per_ranks[rank] // num_workers
-            for worker_rank in range(num_workers)
-        ]
-        for rank in range(world_size)
-    ])
+        num_rows_per_ranks = [ratio * int(rows // ratio) for rows in num_rows_per_ranks]
+
+    num_rows_per_rank_worker = []
+    for rank in range(world_size):
+        rank_extra_rows = num_rows_per_ranks[rank] % ratio
+        num_rows_per_worker = []
+        for worker_rank in range(num_workers):
+            worker_num_rows = num_rows_per_ranks[rank] // ratio * batch_size
+            if rank_extra_rows > 0:
+                worker_num_rows += min(rank_extra_rows, batch_size)
+                rank_extra_rows -= batch_size
+            num_rows_per_worker.append(worker_num_rows)
+        num_rows_per_rank_worker.append(num_rows_per_worker)
+    num_rows_per_rank_worker = np.array(num_rows_per_rank_worker)
 
     row_group_intervals = []
     global_rows = 0
@@ -68,11 +74,12 @@ def associate_to_workers(
 
         row_group_intervals.append(intervals_per_file)
 
-    start_row_index_per_workers = [0] + np.cumsum(num_rows_per_workers).tolist()
-    current_global_row_start = start_row_index_per_workers[current_rank * num_workers + current_worker_rank]
-    current_global_row_end = start_row_index_per_workers[current_rank * num_workers + current_worker_rank + 1]
-    current_intervals = {}
+    start_row_index_per_rank_worker = [0] + np.cumsum(num_rows_per_rank_worker).tolist()
+    current_global_row_start = start_row_index_per_rank_worker[current_rank * num_workers + current_worker_rank]
+    current_global_row_end = start_row_index_per_rank_worker[current_rank * num_workers + current_worker_rank + 1]
+    current_intervals = []
     for intervals_per_file in row_group_intervals:
+        current_file_itvs = []
         for itv in intervals_per_file:
             if itv.global_row_end <= current_global_row_start:
                 continue
@@ -88,13 +95,29 @@ def associate_to_workers(
                 itv.local_row_start = current_global_row_start - itv.global_row_start
                 itv.global_row_start = current_global_row_start
             if itv.global_row_start < current_global_row_end and \
                 current_global_row_end < itv.global_row_end:
                 itv.local_row_end = current_global_row_end - itv.global_row_start
                 itv.global_row_end = current_global_row_end
-
-            if itv.file_index not in current_intervals:
-                current_intervals[itv.file_index] = [itv]
-            else:
-                current_intervals[itv.file_index].append(itv)
-
+
+            current_file_itvs.append(itv)
+        if len(current_file_itvs) > 0:
+            current_intervals.append(current_file_itvs)
+
+    if not drop_last and rank0_extra_rows > 0:
+        for itvs in current_intervals:
+            current_file_itvs = []
+            for itv in itvs:
+                offset = itv.local_row_end - itv.local_row_start
+                rank0_extra_rows -= offset
+                itv = copy.deepcopy(itv)
+                if rank0_extra_rows > 0:
+                    current_file_itvs.append(itv)
+                else:
+                    itv.local_row_end += rank0_extra_rows
+                    itv.global_row_end += rank0_extra_rows
+                    current_file_itvs.append(itv)
+                    break
+            current_intervals.append(current_file_itvs)
+            if rank0_extra_rows <= 0:
+                break
+
     return current_intervals, current_global_row_end-current_global_row_start
diff --git a/tests/test_associate.py b/tests/test_associate.py
index d71e6e7..3c00d1d 100644
--- a/tests/test_associate.py
+++ b/tests/test_associate.py
@@ -19,5 +19,6 @@
         batch_size=11,
     )
     print(num_rows)
-    for idx, itv in intervals.items():
-        print(itv)
\ No newline at end of file
+    for itvs in intervals:
+        for itv in itvs:
+            print(itv)
\ No newline at end of file
diff --git a/tests/test_parallel.py b/tests/test_parallel.py
index 8327c52..de367a5 100644
--- a/tests/test_parallel.py
+++ b/tests/test_parallel.py
@@ -1,7 +1,8 @@
 import torch.distributed as dist
 import torch.multiprocessing as mp
 from parquet_loader import ParquetDataset, ParquetDataLoader
-parquet_path = 'synthetic_data'
+
+parquet_path = '../synthetic_data'
 world_size = 2
 
 def run(rank):
@@ -14,7 +15,7 @@ def run(rank):
     dataset = ParquetDataset(parquet_path, async_read=True)
     dataloader = ParquetDataLoader(dataset, batch_size=66, shuffle=True, num_workers=2)
     for i, batch in enumerate(dataloader):
-        # print(f"{rank}, {i}, {batch.shape}")
+        print(f"{rank}, {i}, {batch.shape}")
         dist.barrier()
 
 
diff --git a/tests/test_single.py b/tests/test_single.py
index 692bc9f..40fc381 100644
--- a/tests/test_single.py
+++ b/tests/test_single.py
@@ -1,7 +1,8 @@
 import torch.distributed as dist
 import torch.multiprocessing as mp
 from parquet_loader import ParquetDataset, ParquetDataLoader
-parquet_path = 'synthetic_data'
+
+parquet_path = '../synthetic_data'
 
 if __name__ == '__main__':
     dataset = ParquetDataset(parquet_path, async_read=True)
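
A minimal usage sketch of the reworked loader API, assuming the '../synthetic_data'
directory used by the tests and benchmarks above; the batch size and worker counts
are illustrative only. ParquetDataLoader now forwards shuffle, drop_last and
batch_size to the dataset via set_shuffle/set_drop_last/set_batch_size, and
len(dataloader) defers to dataset.get_num_batches(), which fetches the parquet
metadata lazily through _try_fetch_metadata().

    from parquet_loader import ParquetDataset, ParquetDataLoader

    # illustrative values; any directory of parquet files works the same way
    dataset = ParquetDataset('../synthetic_data', async_read=True, max_preload=1)
    dataloader = ParquetDataLoader(dataset, batch_size=66, shuffle=True,
                                   drop_last=False, num_workers=2)

    print(len(dataset), len(dataloader))  # total rows, number of batches
    for i, batch in enumerate(dataloader):
        print(i, batch.shape)             # each batch is a numpy array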